In [1]:
import Utils
from Utils import read_txt
import tools
import copy
import os
from PIL import Image
import torch
import warnings
import numpy as np
from tqdm.notebook import tqdm
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from sklearn.metrics import roc_auc_score
import easydict
import time

# Fix Seed
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")

# Fix CPU Limitation
torch.set_num_threads(3)

## Data Loader

In [2]:
class multi_modal_dataset(Dataset):
    def __init__(self, classification_dir, segmentation_dir, txt_COVID, txt_NonCOVID,
                 audio_dir, select_num,
                 lateral_map=1, min_seg=0.01, transform=None):
        
        # Dataset Root Directory
        self.classification_dir = classification_dir
        self.segmentation_dir = segmentation_dir
        self.audio_dir = audio_dir
        
        # Subject
        self.txt_path = [txt_COVID, txt_NonCOVID]
        self.classes = ['CT_COVID', 'CT_NonCOVID']
        self.audio_classes = ['pos', 'neg']
        self.num_cls = len(self.classes)

        self.img_list = []
        self.segment_list = []
        self.audio_list = []
        
        self.min_seg = min_seg
        self.select_num = select_num
        
        for c in range(self.num_cls):
            # Classification List
            cls_list = [[os.path.join(self.classification_dir, self.classes[c], item), c] for item in
                        read_txt(self.txt_path[c])]
            self.img_list += cls_list
            
            # Audio List
            a_list = [[os.path.join(audio_dir, self.audio_classes[c], item), c] for item in os.listdir(os.path.join(audio_dir, self.audio_classes[c]))]
            self.audio_list += a_list
            
        # Split Image COVID & Non-COVID
        self.img_list = np.array(self.img_list)
        pos_index = (self.img_list[:,1]=='0')
        neg_index = (self.img_list[:,1]=='1')
        pos_img_list = self.img_list[:, 0][pos_index]
        neg_img_list = self.img_list[:, 0][neg_index]

        # Select Image List
        select_pos_img_list = np.random.choice(pos_img_list, self.select_num)
        select_neg_img_list = np.random.choice(neg_img_list, self.select_num)

        # Mapping Segment List
        select_pos_seg_list = []
        select_neg_seg_list = []
        for s_p_i_l in select_pos_img_list:
            select_pos_seg_list.append(os.path.join(segmentation_dir, s_p_i_l.split('/')[-2], "lateral_map" + str(lateral_map),
                              s_p_i_l.split('/')[-1].replace('.jpg', '.png')))
        for s_n_i_l in select_neg_img_list:
            select_neg_seg_list.append(os.path.join(segmentation_dir, s_n_i_l.split('/')[-2], "lateral_map" + str(lateral_map),
                              s_n_i_l.split('/')[-1].replace('.jpg', '.png')))

        # Split Audio COIVD & Non-COVID
        self.audio_list = np.array(self.audio_list)
        pos_index = (self.audio_list[:,1]=='0')
        neg_index = (self.audio_list[:,1]=='1')
        pos_audio_list = self.audio_list[:, 0][pos_index]
        neg_audio_list = self.audio_list[:, 0][neg_index]

        # Select Audio List
        select_pos_audio_list = np.random.choice(pos_audio_list, select_num)
        select_neg_audio_list = np.random.choice(neg_audio_list, select_num)
        
        # Make DICT
        self.data_list = []
        for i in range(self.num_cls):
            for j in range(self.select_num):
                if i == 0:
                    data_dict = {'img': select_pos_img_list[j],
                                 'seg': select_pos_seg_list[j],
                                 'audio': select_pos_audio_list[j],
                                 'label': i}
                else:
                    data_dict = {'img': select_neg_img_list[j],
                                 'seg': select_neg_seg_list[j],
                                 'audio': select_neg_audio_list[j],
                                 'label': i}
                self.data_list.append(data_dict)

        self.transform = transform

    def __len__(self):
        return self.select_num*2

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Original Data
        img_path = self.data_list[idx]['img']
        image = Image.open(img_path).convert('RGB')
        image = image.resize((256, 256))

        # Segmentation Data
        seg_path = self.data_list[idx]['seg']
        seg = Image.open(seg_path).convert('RGB')
        seg = seg.rotate(-90, expand=True)
        seg = seg.resize((256, 256))
        
        # Audio Data
        audio_path = self.data_list[idx]['audio']
        audio = Image.open(audio_path).convert('RGB')

        # Mask with Original Data
        # Step 1 => Segmentation Min-Max Normalization + Min Value(Hyperparameter)
        seg_np = np.array(seg)
        seg_mask = (seg_np - seg_np.min()) / (seg_np.max() - seg_np.min()) + self.min_seg
        # Clip max = 1
        seg_mask = np.clip(seg_mask, 0, 1)

        # Step 2 => Original Data with seg_mask
        image_with_mask = np.multiply(image, seg_mask)

        # Step 3 => Change Numpy Dtype => For Using Image Preprocessing
        image_with_mask = Image.fromarray(np.uint8(image_with_mask))
        
        if self.transform:
            image_with_mask = self.transform(image_with_mask)
            audio = self.transform(audio)
        
        sample = {'img': image_with_mask,
                  'audio': audio,
                  'label': int(self.data_list[idx]['label'])}
        return sample

In [3]:
# Normalization
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
trans = transforms.Compose([
    transforms.Resize((224)),
    transforms.ToTensor(),
    normalize
])

# Train Dataset
trainset = multi_modal_dataset(classification_dir='./dataset/image/classfication/',
                               segmentation_dir='./dataset/image/classfication/Segmentation/',
                               txt_COVID='./dataset/image/classfication/Data-split/COVID/trainCT_COVID.txt',
                               txt_NonCOVID='./dataset/image/classfication/Data-split/NonCOVID/trainCT_NonCOVID.txt',
                               audio_dir = './dataset/audio/preprocess/train/',
                               select_num=500,
                               lateral_map=3, min_seg=0.8,
                               transform=trans)

# Validation Dataset
valset = multi_modal_dataset(classification_dir='./dataset/image/classfication/',
                               segmentation_dir='./dataset/image/classfication/Segmentation/',
                               txt_COVID='./dataset/image/classfication/Data-split/COVID/valCT_COVID.txt',
                               txt_NonCOVID='./dataset/image/classfication/Data-split/NonCOVID/valCT_NonCOVID.txt',
                               audio_dir = './dataset/audio/preprocess/validation/',
                               select_num=100,
                               lateral_map=3, min_seg=0.8,
                               transform=trans)

# Test Dataset
testset = multi_modal_dataset(classification_dir='./dataset/image/classfication/',
                               segmentation_dir='./dataset/image/classfication/Segmentation/',
                               txt_COVID='./dataset/image/classfication/Data-split/COVID/testCT_COVID.txt',
                               txt_NonCOVID='./dataset/image/classfication/Data-split/NonCOVID/testCT_NonCOVID.txt',
                               audio_dir = './dataset/audio/preprocess/test/',
                               select_num=100,
                               lateral_map=3, min_seg=0.8,
                               transform=trans)

# Data Loader
multi_train_loader = DataLoader(trainset, batch_size=5, drop_last=True, shuffle=True)
multi_val_loader = DataLoader(valset, batch_size=5, drop_last=True, shuffle=False)
multi_test_loader = DataLoader(testset, batch_size=5, drop_last=True, shuffle=False)

In [4]:
for batch_index, batch_samples in enumerate(multi_train_loader):
    image, audio, label = batch_samples['img'], batch_samples['audio'], batch_samples['label']
    break

### PreTrain Model Load

In [5]:
# DenseNet => Audio, CT BaseLine Model
print('Base Model Load....')
base_model = models.densenet169(pretrained=True)

# Image Model
print('Image Model Load....')
image_model = models.densenet169(pretrained=True)
image_model.classifier = nn.Sequential(nn.Linear(1664, 2), nn.Softmax(dim=1))
image_model.load_state_dict(torch.load('./model/single_modality/3_0.8.pt', map_location='cpu'))
image_model
image_model.eval()

# Audio Model
print('Audio Model Load....')
audio_model = models.densenet169(pretrained=True)
audio_model.classifier = nn.Linear(1664, 1)
audio_model = nn.Sequential(audio_model, nn.Sigmoid())
audio_model.load_state_dict(torch.load('./model/single_modality/audio.pt', map_location='cpu'))
audio_model
audio_model.eval()

print('All Load....')

Base Model Load....
Image Model Load....
Audio Model Load....
All Load....


### Image Feature Extractor List
- Output Layer
- Transition3
- Transition2

In [6]:
imgae_feature_extractor = []

for child in image_model.children():
    imgae_feature_extractor.append(nn.Sequential(child, nn.AvgPool2d((7,7))))
    imgae_feature_extractor.append(nn.Sequential(*list(child.children())[:-2], nn.AvgPool2d((7,7))))
    imgae_feature_extractor.append(nn.Sequential(*list(child.children())[:-4], nn.AvgPool2d((14,14))))
    break

### Audio Feature Extractor List
- Output Layer
- Transition3
- Transition2

In [7]:
audio_feature_extractor = []

for i, (name, module) in enumerate(audio_model[0]._modules.items()):
    audio_feature_extractor.append(nn.Sequential(child, nn.AvgPool2d((7,7))))
    audio_feature_extractor.append(nn.Sequential(*list(child.children())[:-2], nn.AvgPool2d((7,7))))
    audio_feature_extractor.append(nn.Sequential(*list(child.children())[:-4], nn.AvgPool2d((14,14))))
    break

### Feature Extractor List Test

In [8]:
out1 = audio_feature_extractor[0](audio)
out2 = audio_feature_extractor[1](audio)
out3 = audio_feature_extractor[2](audio)

In [9]:
out1.shape

torch.Size([5, 1664, 1, 1])

In [10]:
out2.shape

torch.Size([5, 640, 1, 1])

In [11]:
out3.shape

torch.Size([5, 256, 1, 1])

In [12]:
# Delete PreTrain Model on Memory
del base_model
del image_model
del audio_model

# Delete DataLoader
del multi_train_loader
del multi_val_loader
del multi_test_loader

### Config => Define Search Space

In [13]:
# Audio Feature
# Image Feature
# Activation Function

def get_possible_layer_configurations():
    def get_max_labels():
        return (3, 3, 2)

    list_conf = []
    max_labels = get_max_labels()
    
    # Audio Feature Extractor => 3
    for audio in range(max_labels[0]):
        # Image Feature Extractor => 3
        for image in range(max_labels[1]):
            # Num of Activate List => 3
            for activation in range(max_labels[2]):
                conf = [audio, image, activation]
                list_conf.append(conf)

    return list_conf

### Simple Surrogate

In [14]:
class SimpleRecurrentSurrogate(nn.Module):
    # number_input_feats => Hidden layer of the Image, Audio, Activation
    def __init__(self, num_hidden=100, number_input_feats=3, size_ebedding=100):
        super(SimpleRecurrentSurrogate, self).__init__()

        self.num_hidden = num_hidden

        # input embedding
        self.embedding = nn.Sequential(nn.Linear(number_input_feats, size_ebedding),
                                       nn.Sigmoid())
        # the LSTM
        self.lstm = nn.LSTM(size_ebedding, num_hidden)
        # The linear layer that maps from hidden state space to output space
        self.hid2val = nn.Linear(num_hidden, 1)

        self.nonlinearity = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.uniform_(-0.1, 0.1)
                m.bias.data.fill_(1.8)

    def forward(self, sequence_of_operations):
        # (seq_len, batch, input_size):

        embeds = []
        for s in sequence_of_operations:
            embeds.append(self.embedding(s))
        embeds = torch.stack(embeds, dim=0)

        lstm_out, hidden = self.lstm(embeds)

        val_space = self.hid2val(lstm_out[-1])
        val_space = self.nonlinearity(val_space)

        return val_space

    def eval_model(self, sequence_of_operations_np, device):
        # the user will give this data sample as numpy array (int) with size len_seq x input_size

        npseq = np.expand_dims(sequence_of_operations_np, 1)
        sequence_of_operations = torch.from_numpy(npseq).float().to(device)
        res = self.forward(sequence_of_operations)
        res = res.cpu().data.numpy()

        return res[0, 0]

In [15]:
class SurrogateDataloader():

    def __init__(self):
        self._dict_data = {}

    def add_datum(self, datum_conf, datum_acc):
        # data_conf is of size [seq_len, len_data]

        seq_len = len(datum_conf)
        datum_hash = datum_conf.data.tobytes()

        if seq_len in self._dict_data:

            if datum_hash in self._dict_data[seq_len]:
                # if the configuration is already stored, keep the max accuracy
                self._dict_data[seq_len][datum_hash] = (
                datum_conf, max(datum_acc, self._dict_data[seq_len][datum_hash][1]))
            else:
                self._dict_data[seq_len][datum_hash] = (datum_conf, datum_acc)
        else:
            self._dict_data[seq_len] = {datum_hash: (datum_conf, datum_acc)}

    def get_data(self, to_torch=False):
        # delivers list of numpy tensors of size [seq_len, num_layers, len_data]

        dataset_conf = list()
        dataset_acc = list()

        for len_key, data_dict in self._dict_data.items():

            conf_list = list()
            acc_list = list()

            for datum_hash, datum in data_dict.items():
                conf_list.append(datum[0])
                acc_list.append(datum[1])

            conf_list = np.transpose(np.asarray(conf_list, np.float32), (1, 0, 2))

            dataset_conf.append(np.array(conf_list, np.float32))
            dataset_acc.append(np.expand_dims(np.array(acc_list, np.float32), 1))

        if to_torch:
            for index in range(len(dataset_conf)):
                dataset_conf[index] = torch.from_numpy(dataset_conf[index])
                dataset_acc[index] = torch.from_numpy(dataset_acc[index])

        return dataset_conf, dataset_acc

    def get_k_best(self, k):

        dataset_conf = list()
        dataset_acc = list()

        for len_key, data_dict in self._dict_data.items():
            for datum_hash, datum in data_dict.items():
                dataset_conf.append(datum[0])
                dataset_acc.append(datum[1])

        dataset_acc = np.array(dataset_acc)
        top_k_idx = np.argpartition(dataset_acc, -k)[-k:]

        confs = [dataset_conf[i] for i in top_k_idx]
        accs = [dataset_acc[i] for i in top_k_idx]

        return (confs, accs, top_k_idx)

### Searchable ANN

In [16]:
# Input => Feature Extractor => Fusion Layer => Classification
class Searchable_ANN(nn.Module):
    def __init__(self, conf, audio_feature_extractor, image_feature_extractor, device):
        super(Searchable_ANN, self).__init__()

        # conf[0] => image hidden layer
        # conf[1] => audio hidden layer
        # conf[2] => activation function

        self.conf = conf
        self.device = device

        # Pre-Train Feature Extractor
        self.audio_feature_extractor = audio_feature_extractor
        self.image_feature_extractor = image_feature_extractor
        
        # Define Input Size
        input_size = [1664, 640, 256]
        
        # Defin Output Size
        self.out_size = 100
        
        self.alphas = [(input_size[conf[0]], input_size[conf[1]]) for conf in self.conf]

        # Define Fuse Layer
        self.fusion_layers = self._create_fc_layers()

        # Classification => COVID or Non-COVID
        self.central_classifier = nn.Sequential(nn.Linear(self.out_size, 1), nn.Sigmoid())

    # tensor_tuple => CSF, PET, SMRI
    def forward(self, image, audio):
        # Image Feature
        image_features = [self.image_feature_extractor[0].to(self.device)(image.to(self.device)).squeeze(), 
                          self.image_feature_extractor[1].to(self.device)(image.to(self.device)).squeeze(),
                          self.image_feature_extractor[2].to(self.device)(image.to(self.device)).squeeze()]
        
        image_features = [image_features[idx] for idx in self.conf[:, 0]]

        # Audio Feature
        audio_features = [self.audio_feature_extractor[0].to(self.device)(audio.to(self.device)).squeeze(),
                          self.audio_feature_extractor[1].to(self.device)(audio.to(self.device)).squeeze(),
                          self.audio_feature_extractor[2].to(self.device)(audio.to(self.device)).squeeze()]
        
        audio_features = [audio_features[idx] for idx in self.conf[:, 1]]

        # Fusion Feature
        for fusion_idx, conf in enumerate(self.conf):
            image_feat = image_features[fusion_idx]
            audio_feat = audio_features[fusion_idx]

            if fusion_idx == 0:
                fused = torch.cat((image_feat, audio_feat), 1)
                out = self.fusion_layers[fusion_idx](fused)

            else:
                fused = torch.cat((image_feat, audio_feat, out), 1)
                out = self.fusion_layers[fusion_idx](fused)

        # Dropout with Classification
        out = self.central_classifier(out)
        return out

    def central_params(self):
        central_parameters = [
            {'params': self.fusion_layers.parameters()},
            {'params': self.central_classifier.parameters()}
        ]

        return central_parameters

    def _create_fc_layers(self):
        fusion_layers = []

        for i, conf in enumerate(self.conf):
            in_size = sum(self.alphas[i])

            # args.inner_representation_size => ANN Output Size
            if i > 0:
                in_size += self.out_size

            out_size = self.out_size

            # Activation Function
            if conf[2] == 0:
                nl = nn.Sigmoid()
            elif conf[2] == 1:
                nl = nn.ReLU()

            op = nn.Sequential(nn.Linear(in_size, out_size), nl)
            fusion_layers.append(op)

        return nn.ModuleList(fusion_layers)

**Scheduler**

In [17]:
class LRCosineAnnealingScheduler():

    def __init__(self, eta_max, eta_min, Ti, Tmultiplier, num_batches_per_epoch):

        self.eta_min = eta_min
        self.eta_max = eta_max
        self.Ti = Ti
        self.Tcur = 0.0
        self.nbpe = num_batches_per_epoch
        self.iteration_counter = 0.0
        self.eta = eta_max
        self.Tm = Tmultiplier

    def _compute_rule(self):
        self.eta = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + np.cos(np.pi * self.Tcur / self.Ti))
        return self.eta

    def step(self):

        self.Tcur = self.iteration_counter / self.nbpe
        self.iteration_counter = self.iteration_counter + 1.0
        eta = self._compute_rule()

        if eta <= self.eta_min + 1e-10:
            self.Tcur = 0
            self.Ti = self.Ti * self.Tm
            self.iteration_counter = 0

        return eta

    def update_optimizer(self, optimizer):
        state_dict = optimizer.state_dict()
        for param_group in state_dict['param_groups']:
            param_group['lr'] = self.eta
        optimizer.load_state_dict(state_dict)

### Simple Model Train

In [18]:
def train_ntu_track_acc(model, criteria, optimizer, scheduler, dataloaders, dataset_sizes, device=None, num_epochs=5):
    
    best_model_sd = copy.deepcopy(model.state_dict())
    best_acc = 0
    
    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']: 

            if phase == 'train':
                if not isinstance(scheduler, LRCosineAnnealingScheduler):
                    scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            
            # Iterate over data.
            for data in dataloaders[phase]:     
                
                # get the inputs
                image, audio, label = data['img'], data['audio'], data['label']                
                
                # device
                image = image.to(device)
                audio = audio.to(device)                
                label = label.to(device)
                
                # zero the parameter gradients
                optimizer.zero_grad()     
                
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(image, audio)
                    loss = criteria(output, label.float().reshape(-1,1))
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        scheduler.step()
                        scheduler.update_optimizer(optimizer)
                        loss.backward()
                        optimizer.step()
                        
                # statistics
                running_loss += loss.item() * image.size(0)
                output[output<=0.5] = 0
                output[output>0.5] = 1
                
                running_corrects += torch.sum(output.squeeze() == label.data)
        
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc  = running_corrects.double() / dataset_sizes[phase]
    
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            # deep copy the model
            if phase == 'dev' and epoch_acc > best_acc:
                best_acc = epoch_acc
                
    model.train(False)            
    
    return best_acc

In [19]:
# Modeltype => conf, audio_feature_extractor, image_feature_extractor, device
def train_sampled_models(sampled_configurations, searchable_type, dataloaders,
                         args, device,
                         audio_feature_extractor, image_feature_extractor):
    
    dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'dev']}
    num_batches_per_epoch = dataset_sizes['train'] / args.batchsize
    criterion = torch.nn.BCELoss()

    real_accuracies = []

    for idx, configuration in enumerate(sampled_configurations):

        rmode = searchable_type(configuration, audio_feature_extractor, image_feature_extractor, device)
        params = rmode.parameters()

        # optimizer and scheduler
        optimizer = optim.Adam(params, lr=args.eta_max, weight_decay=1e-4)
        scheduler = LRCosineAnnealingScheduler(args.eta_max, args.eta_min, args.Ti, args.Tm,
                                                      num_batches_per_epoch)
        rmode.to(device)
        print('Now training: ')
        print(configuration)

        best_model_acc = train_ntu_track_acc(rmode, criterion, optimizer, scheduler, dataloaders,
                                                        dataset_sizes,
                                                        device=device, num_epochs=args.epochs)
        # Append Result
        real_accuracies.append(best_model_acc)

    return real_accuracies

### Model Searcher

In [20]:
class ModelSearcher():
    def __init__(self, args):
        self.args = args

    def search(self):
        pass

    def _epnas(self, model_type, surrogate_dict, dataloaders, dataset_searchmethods, device, audio_feature_extractor, image_feature_extractor):

        # surrogate
        surrogate = surrogate_dict['model']
        s_crite = surrogate_dict['criterion']
        s_data = SurrogateDataloader()
        s_optim = optim.Adam(surrogate.parameters(), lr=0.001)

        # search functions that are specific to the dataset
        train_sampled_models = dataset_searchmethods['train_sampled']
        get_possible_layer_configurations = dataset_searchmethods['get_layer_confs']

        temperature = 10.0

        sampled_k_confs = []

        shared_weights = dict()

        # repeat process search_iterations times
        for si in range(2):
            print(50 * "=")
            print("Search iteration {}/{} ".format(si, 3))

            # for each fusion
            for progression_index in range(2):

                print(25 * "-")
                print("Progressive step {}/{} ".format(progression_index, 3))

                # Step 1: unfold layer (fusion index)
                list_possible_layer_confs = get_possible_layer_configurations()

                # Step 2: merge previous top with unfolded configurations
                all_configurations = tools.merge_unfolded_with_sampled(sampled_k_confs, list_possible_layer_confs,
                                                                       progression_index)

                # Step 3: obtain accuracies for all possible unfolded configurations
                # if first execution, just train all, if not, use surrogate to predict them
                if si + progression_index == 0:
                    all_accuracies = train_sampled_models(all_configurations, model_type, dataloaders, self.args, device, audio_feature_extractor, image_feature_extractor)
                    tools.update_surrogate_dataloader(s_data, all_configurations, all_accuracies)
                    tools.train_surrogate(surrogate, s_data, s_optim, s_crite, self.args, device)

                    print("Trained architectures: ")
                    print(list(zip(all_configurations, all_accuracies)))
                    
                else:
                    all_accuracies = tools.predict_accuracies_with_surrogate(all_configurations, surrogate, device)
                    print("Predicted accuracies: ")
                    print(list(zip(all_configurations, all_accuracies)))

                # Step 4: sample K architectures and train them. 
                # this should happen only if not first iteration because in that case, 
                # all confs were trained in step 3
                if si + progression_index == 0:
                    sampled_k_confs = tools.sample_k_configurations(all_configurations, all_accuracies,
                                                                    self.args.num_samples, temperature)

                    estimated_accuracies = tools.predict_accuracies_with_surrogate(all_configurations, surrogate,
                                                                                       device)
                    diff = np.abs(np.array(estimated_accuracies) - np.array(all_accuracies))
                    print("Error on accuracies = {}".format(diff))

                else:
                    sampled_k_confs = tools.sample_k_configurations(all_configurations, all_accuracies,
                                                                    self.args.num_samples, temperature)
                    sampled_k_accs = train_sampled_models(sampled_k_confs, model_type, dataloaders, self.args, device, audio_feature_extractor, image_feature_extractor)

                    tools.update_surrogate_dataloader(s_data, sampled_k_confs, sampled_k_accs)
                    err = tools.train_surrogate(surrogate, s_data, s_optim, s_crite, self.args, device)

                    print("Trained architectures: ")
                    print(list(zip(sampled_k_confs, sampled_k_accs)))
                    print("with surrogate error: {}".format(err))

                # temperature decays at each step
                iteration = si * self.args.search_iterations + progression_index
                temperature = tools.compute_temperature(iteration, self.args)
                print("Temperature is being set to {}".format(temperature))

        return s_data

In [21]:
class Multi_Modal_Searcher(ModelSearcher):
    def __init__(self, args, trainset, valset, device, audio_feature_extractor, image_feature_extractor):
        super(Multi_Modal_Searcher, self).__init__(args)

        self.device = device

        datasets = {'train': trainset, 'dev': valset}
        
        self.dataloaders = {
            x: DataLoader(datasets[x], batch_size=args.batchsize, shuffle=True, num_workers=1,
                          drop_last=True) for x in ['train', 'dev']}
        self.audio_feature_extractor = audio_feature_extractor
        self.image_feature_extractor = image_feature_extractor

    def search(self):
        surrogate = SimpleRecurrentSurrogate(100, 3, 100)
        surrogate.to(self.device)
        surrogate_dict = {'model': surrogate, 'criterion': torch.nn.MSELoss()}
        ntu_searchmethods = {'train_sampled': train_sampled_models,
                             'get_layer_confs': get_possible_layer_configurations}

        return self._epnas(Searchable_ANN, surrogate_dict, self.dataloaders, ntu_searchmethods, self.device, self.audio_feature_extractor, self.image_feature_extractor)

### Search NAS by MFAS

**Hyperparameter**

In [22]:
args = easydict.EasyDict({ "epochs": 2, "search_iterations": 3,
                          "eta_max": 0.001, "eta_min": 0.000001, "Ti": 1, "Tm": 2,
                          "batchsize": 10, "num_samples": 10})

**MFAS**

In [23]:
ntu_searcher = Multi_Modal_Searcher(args, trainset, valset, device, audio_feature_extractor, imgae_feature_extractor)

In [None]:
print("MFAS for NTU Started!!!!")
start_time = time.time()
surrogate_data = ntu_searcher.search()
time_elapsed = time.time() - start_time
print('Search complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

MFAS for NTU Started!!!!
Search iteration 0/3 
-------------------------
Progressive step 0/3 
Now training: 
[[0 0 0]]
train Loss: 0.3103 Acc: 0.8830
dev Loss: 0.3683 Acc: 0.8550
