# download model

In [1]:
from urllib.parse import urlparse
from torch import hub
import re
import os
import torch
import requests
from torchvision import models
import glob
from natsort import natsorted
from skimage import io, segmentation, morphology, measure, exposure
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

In [2]:
def download_model(url, dst_path):
    """Download a checkpoint from ``url`` into the directory ``dst_path``.

    Delegates to ``torch.hub.download_url_to_file`` with a progress bar and without
    hash verification (``hash_prefix=None``).

    Returns:
        The bare file name (last component of the URL path), not the full local path.
    """
    model_filename = os.path.basename(urlparse(url).path)
    target_file = os.path.join(dst_path, model_filename)
    hub.download_url_to_file(url, target_file, hash_prefix=None, progress=True)
    return model_filename


In [3]:
def download_file(url, dst_path):
    """Stream ``url`` into the directory ``dst_path`` and return the local file path.

    The local file name is the last component of the URL path. The body is written in
    8192-byte chunks, so the whole download never has to fit in memory.

    Args:
        url: HTTP(S) URL to fetch.
        dst_path: existing directory to write into.

    Returns:
        Path of the written file (the function used to return ``None``).

    Raises:
        requests.HTTPError: on a 4xx/5xx response, instead of silently saving the
            server's error page as if it were the requested file.
    """
    filename = os.path.basename(urlparse(url).path)
    local_filename = os.path.join(dst_path, filename)
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with open(local_filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty chunks
                    file.write(chunk)
    return local_filename


# Load packages and set the paths to the neuron responses

In [1]:
from urllib.parse import urlparse
from torch import hub
import re
import os
import torch
from torch import nn
import requests
import torchvision
from torchvision import models,datasets,transforms
import glob
from natsort import natsorted
from skimage import io, segmentation, morphology, measure, exposure
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from scipy.io import loadmat, savemat
from scipy.stats import pearsonr, spearmanr, friedmanchisquare
import matplotlib.pyplot as plt
import matplotlib as mpl
import time
from PIL import Image
import sys
import json
import pandas as pd
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def loadmat_data(filename):
    """Load a MATLAB ``.mat`` file and return its first data variable.

    ``scipy.io.loadmat`` returns a dict that also holds metadata entries whose keys
    start with ``'__'`` (``__header__``, ``__version__``, ``__globals__``). Rather than
    assuming the data always sits at position 3 of ``keys()``, the first key that is
    not such a metadata entry is selected explicitly.

    Args:
        filename: path to the ``.mat`` file (``loadmat`` appends ``.mat`` when missing).

    Returns:
        The stored variable (a ``numpy.ndarray`` for numeric data).

    Raises:
        ValueError: if the file contains no data variable.
    """
    file = loadmat(filename)
    data_keys = [key for key in file if not key.startswith('__')]
    if not data_keys:
        raise ValueError(f'{filename!r} contains no data variables')
    return file[data_keys[0]]

In [3]:
# Analysis folders of the four awake-monkey calcium-imaging datasets (relative to the notebook).
(MonkeyA_path, MonkeyB_path, MonkeyC_path, MonkeyD_path) = (
    './calcium_imaging_awake_monkey/MA/Analysis',
    './calcium_imaging_awake_monkey/MB_CC/Analysis',
    './calcium_imaging_awake_monkey/MC_CC/Analysis',
    './calcium_imaging_awake_monkey/MD/Analysis',
)

# define Dataset transforms

In [4]:
# Fixed 256x256 resize and 224x224 center crop, plus the usual ImageNet mean/std
# normalisation (the dataset classes below reference the global `normalize`).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
torch_resize = transforms.Resize(size=[256, 256])
torch_crop = transforms.CenterCrop(size=[224, 224])

# Load the pretrained model (pretrained=True)

In [5]:
# Expose only GPU 2 to this process; this only takes effect if CUDA has not been
# initialised yet in the kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = models.squeezenet1_1(pretrained=True)
model = model.to(device)
model.eval()


SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): MaxPool2d

# Capture the output of every layer in the model with forward hooks

In [6]:
# Forward hooks record every module's output in `layer_outputs`, keyed by the module's
# qualified name from `model.named_modules()` (e.g. 'features.6'; the root model is '').
layer_outputs = {}
def hook_fn(module, input, output, layer_name):
    """Store a snapshot of ``output`` under ``layer_name`` in ``layer_outputs``.

    Tensor outputs are detached and cloned. SqueezeNet follows several convolutions
    with ``ReLU(inplace=True)``, so keeping a plain reference would let that ReLU
    overwrite the stored value: 'features.0' would silently hold the post-ReLU
    activations of 'features.1'. Non-tensor outputs are stored as-is.
    """
    if isinstance(output, torch.Tensor):
        output = output.detach().clone()
    layer_outputs[layer_name] = output

hook_handles = []  # keep the handles so the hooks can later be removed with .remove()
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.ModuleList):
        continue
    # Bind `name` as a default argument; a plain closure would only ever see the last name.
    hook = layer.register_forward_hook(lambda module, input, output, name=name: hook_fn(module, input, output, name))
    hook_handles.append(hook)

# Find the layer most similar to V1

## define dataset

In [8]:
class TestDatasetsSimilar(Dataset):
    """Image dataset for the layer/V1 similarity analysis.

    Each item is the image at ``files[index]`` converted to RGB, resized (shorter
    side 256), center-cropped to 224x224, converted to a tensor and normalised with
    the global ``normalize`` transform. No label is returned.
    """
    def __init__(self, files) -> None:
        super().__init__()
        self.files = files
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        file = self.files[index]
        # Force 3 channels: grayscale or RGBA stimuli would otherwise fail in the
        # 3-channel `normalize` step. The context manager closes the file handle.
        with Image.open(file) as img:
            img_data = img.convert('RGB')
        img_data = self.transform(img_data)

        return img_data

## get the output from every layer in model

In [9]:
data_path = './gabor/'
# natsort orders "2.jpg" before "10.jpg"; presumably image i then corresponds to ori[i]
# used below — verify against the stimulus file naming.
img_names = natsorted(glob.glob(os.path.join(data_path, "*.jpg")))
test_datasets = TestDatasetsSimilar(img_names)
test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

In [10]:
# A single forward pass fills `layer_outputs` through the hooks and reveals every hooked module name.
img = next(iter(test_loader)).to(device)
with torch.no_grad():
    output = model(img)
all_layer_names = [layer_key for layer_key in layer_outputs]
all_layer_names

['features.0',
 'features.1',
 'features.2',
 'features.3.squeeze',
 'features.3.squeeze_activation',
 'features.3.expand1x1',
 'features.3.expand1x1_activation',
 'features.3.expand3x3',
 'features.3.expand3x3_activation',
 'features.3',
 'features.4.squeeze',
 'features.4.squeeze_activation',
 'features.4.expand1x1',
 'features.4.expand1x1_activation',
 'features.4.expand3x3',
 'features.4.expand3x3_activation',
 'features.4',
 'features.5',
 'features.6.squeeze',
 'features.6.squeeze_activation',
 'features.6.expand1x1',
 'features.6.expand1x1_activation',
 'features.6.expand3x3',
 'features.6.expand3x3_activation',
 'features.6',
 'features.7.squeeze',
 'features.7.squeeze_activation',
 'features.7.expand1x1',
 'features.7.expand1x1_activation',
 'features.7.expand3x3',
 'features.7.expand3x3_activation',
 'features.7',
 'features.8',
 'features.9.squeeze',
 'features.9.squeeze_activation',
 'features.9.expand1x1',
 'features.9.expand1x1_activation',
 'features.9.expand3x3',
 'feat

In [11]:
# Run every stimulus through the network and keep the flattened output of each top-level
# `features.<k>` block, keyed by the matching value in `ori` (15, 30, ..., 180).
each_img_each_layer_outputs = {}
layer_name = ['features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5',
              'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11',
              'features.12']
ori = np.arange(15, 181, 15)
with torch.no_grad():
    for i, img in enumerate(test_loader):
        img = img.to(device)
        output = model(img)
        each_img_each_layer_outputs[ori[i]] = {
            layer: layer_outputs[layer].flatten() for layer in layer_name
        }

## Load the V1 RDMs of the four monkeys

In [12]:
# Load the 12x12 V1 RDM of each monkey from its analysis folder.
monkey_analysis_dirs = {
    'MonkeyA': MonkeyA_path,
    'MonkeyB': MonkeyB_path,
    'MonkeyC': MonkeyC_path,
    'MonkeyD': MonkeyD_path,
}
RDM_monkey = {}
for monkey_key, analysis_dir in monkey_analysis_dirs.items():
    data_path = os.path.join(analysis_dir, './geometry/RDM_n_mean_12.mat')
    RDM_ori = loadmat_data(data_path)
    RDM_monkey[monkey_key] = RDM_ori


## Calculate the similarity between the ANN and V1

In [13]:
# For every candidate layer: build an RDM from the Euclidean distances between the
# flattened responses to the stimuli in `ori`, min-max normalise it, and
# Spearman-correlate it with each monkey's V1 RDM.
RDM_similiarty_all_layers = {}
RDM_similiarty_p_all_layers = {}  # matching p-values (previously computed and discarded)
oln = len(ori)
for target_layer in layer_name:
    target_value = {key: value[target_layer]
                    for key, value in each_img_each_layer_outputs.items()
                    if target_layer in value}

    RDmatrix = torch.zeros((oln, oln)).to(device)
    for i in range(oln):
        for j in range(i + 1, oln):
            # The distance is symmetric and 0 on the diagonal, so compute each pair once.
            distance = torch.norm(target_value[ori[i]] - target_value[ori[j]])
            RDmatrix[i, j] = distance
            RDmatrix[j, i] = distance
    min_val = torch.min(RDmatrix)
    max_val = torch.max(RDmatrix)
    value_range = max_val - min_val
    if value_range > 0:
        RDmatrix = (RDmatrix - min_val) / value_range
    # else: all responses are identical -> keep the all-zero matrix instead of 0/0 = NaN
    RDmatrix = RDmatrix.cpu().numpy().flatten()

    RDM_similiarty_r = np.zeros(len(RDM_monkey))
    RDM_similiarty_p = np.zeros(len(RDM_monkey))
    for i, monkey_rdm in enumerate(RDM_monkey.values()):
        RDM_similiarty_r[i], RDM_similiarty_p[i] = spearmanr(monkey_rdm.flatten(), RDmatrix)
    RDM_similiarty_all_layers[target_layer] = RDM_similiarty_r
    RDM_similiarty_p_all_layers[target_layer] = RDM_similiarty_p


## find the target layer

In [15]:
# Average each layer's Spearman r over the monkeys and keep the layer with the highest mean.
average_dcit = {}
for layer_key, scores in RDM_similiarty_all_layers.items():
    average_dcit[layer_key] = sum(scores) / len(scores)
most_resembles_layer = max(average_dcit, key=lambda layer_key: average_dcit[layer_key])

# Select orientation-tuned and non-orientation-tuned neurons in the ANN

## define dataset

In [54]:
class TestDatasets(Dataset):
    """Image dataset for the orientation-tuning (neuron selection) step.

    Each item is the image at ``files[index]`` converted to RGB, resized (shorter
    side 256), center-cropped to 224x224, converted to a tensor and normalised with
    the global ``normalize`` transform. No label is returned.
    """
    def __init__(self, files) -> None:
        super().__init__()
        self.files = files
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        file = self.files[index]
        # Force 3 channels: grayscale or RGBA stimuli would otherwise fail in the
        # 3-channel `normalize` step. The context manager closes the file handle.
        with Image.open(file) as img:
            img_data = img.convert('RGB')
        img_data = self.transform(img_data)

        return img_data

## run

In [56]:
# Stimulus layout: `oln` orientations (15..180 deg in 15 deg steps) x 15 samples each.
each_ori_sample_numbers = 15
ori = np.arange(15, 181, 15)
oln = len(ori)
data_path = './gabor_select_neurons/'
img_names = natsorted(glob.glob(os.path.join(data_path, "*.jpg")))
test_datasets = TestDatasets(img_names)
test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

In [57]:
# Probe once to learn how many units the target layer has, then allocate the
# (orientation, sample, unit) response buffer directly on the device.
img = next(iter(test_loader)).to(device)
with torch.no_grad():
    output = model(img)
target_layer_output = layer_outputs[most_resembles_layer].flatten()
n_target_units = target_layer_output.numel()
ANN_response = torch.zeros((oln, each_ori_sample_numbers, n_target_units), device=device)

In [58]:
# Images are ordered orientation-major: image i is sample (i % 15) of orientation (i // 15).
with torch.no_grad():
    for i, img in enumerate(test_loader):
        ori_index, sample_index = divmod(i, each_ori_sample_numbers)
        img = img.to(device)
        output = model(img)
        target_layer_output = layer_outputs[most_resembles_layer].flatten()
        ANN_response[ori_index, sample_index, :] = target_layer_output
    

In [59]:
# Friedman test per unit: rows are the orientations, columns the samples per orientation.
# p < 0.01 marks the unit as orientation-tuned.
# Copy the whole response tensor to host memory once. The old per-unit `.cpu()` did one
# device-to-host transfer for every unit (~1.9e5 here).
ANN_response_host = ANN_response.cpu().numpy()
n_units = ANN_response_host.shape[2]
ori_neurons_index = np.zeros(n_units, dtype=bool)

for neuron in range(n_units):
    neuron_response = ANN_response_host[:, :, neuron]
    statistic, p_value = friedmanchisquare(*neuron_response)
    # If the test yields NaN (as can happen for a unit with a constant response), the
    # comparison is False and the unit counts as not tuned.
    ori_neurons_index[neuron] = p_value < 0.01
    

In [60]:
# Keep only the units that failed the tuning test, then move both arrays to NumPy for saving.
non_ori_mask = ~ori_neurons_index
ANN_response_non_ori = ANN_response[:, :, non_ori_mask].cpu().numpy()
ANN_response = ANN_response.cpu().numpy()

In [61]:
# Count and fraction of units that are NOT orientation-tuned.
n_total_units = len(ori_neurons_index)
Non_number = n_total_units - sum(ori_neurons_index)
Non_percent = Non_number / n_total_units
for stat_label, stat_value in (('Non_number', Non_number), ('Non_percent', Non_percent)):
    print(f'{stat_label}:{stat_value}')

Non_number:29607
Non_percent:0.15864519032921812


In [62]:
# savemat does not create missing folders, so make sure the output directory exists first.
os.makedirs('./ANN_output/Squeezenet/', exist_ok=True)
savemat('./ANN_output/Squeezenet/ann_ori_neurons_index.mat', {'ori_neurons_index': ori_neurons_index})
savemat('./ANN_output/Squeezenet/ANN_response.mat', {'ANN_response': ANN_response})
savemat('./ANN_output/Squeezenet/ANN_response_non_ori.mat', {'ANN_response_non_ori': ANN_response_non_ori})

# Ablation experiment

## load functions and set dir

In [18]:
# Root folder for every artefact of the ablation experiments (same folder the selection results were saved to).
output_dir = './ANN_output/Squeezenet/'

### evaluation metrics

In [19]:
class AverageMeter:
    """Track the latest value and the running weighted mean of a metric.

    ``update(val, n)`` records ``val`` as the current value and adds it ``n`` times
    (e.g. a per-batch mean weighted by the batch size) to the running sum.
    ``str(meter)`` renders ``"<name> <val> (<avg>)"`` with ``fmt`` as the format spec.
    """

    def __init__(self, name, fmt=':f') -> None:
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self) -> str:
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(name=self.name, val=self.val, avg=self.avg)

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of ``output`` logits against ``target`` labels.

    Args:
        output: ``(batch, classes)`` scores.
        target: ``(batch,)`` integer class labels.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples whose label
        is among the k highest-scoring classes.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        top_idx = output.topk(max(topk), 1, True, True).indices.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100 / n_samples)
            for k in topk
        ]
        

### Load the ImageNet validation set

In [20]:
class DefaultConfigs(object):
    """Static settings for the ImageNet ablation runs (class-level attributes).

    Attributes:
        val_dir: root of the ImageNet validation set, read with ``ImageFolder``.
            Defaults to ``/DATA/ImageNet/val``; set the ``IMAGENET_VAL_DIR``
            environment variable to point elsewhere without editing the notebook.
        model_name: label of the network under test.
        batch_size: mini-batch size of the validation ``DataLoader``.
        interval: running accuracy is logged every ``interval`` batches.
    """
    # Paths and run parameters
    val_dir = os.environ.get("IMAGENET_VAL_DIR", "/DATA/ImageNet/val")
    model_name = "Squeezenet"
    batch_size = 4
    interval = 10

config = DefaultConfigs()

In [21]:

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Standard evaluation preprocessing: resize (shorter side 256), center crop 224, tensor, normalise.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
val_img_dataset = datasets.ImageFolder(config.val_dir, val_transform)
val_img_loader = DataLoader(val_img_dataset, batch_size=config.batch_size, shuffle=False)
# folder (class) name -> integer label
mapping = val_img_dataset.class_to_idx

# Fetch one batch as a quick check that the loader works.
image, target = next(iter(val_img_loader))

### functions

In [22]:
def calculate_class_acc(targets, preds):
    """Per-class sample counts, correct counts and accuracy.

    Args:
        targets: sequence of true labels.
        preds: sequence of predicted labels, indexed in step with ``targets``.

    Returns:
        ``(class_counts, correct_counts, class_accuracy)``: dicts keyed by label in
        first-seen order. ``correct_counts`` only contains labels with at least one
        correct prediction; ``class_accuracy`` maps every label to correct / total.
    """
    class_counts = {}
    correct_counts = {}
    for idx, target in enumerate(targets):
        class_counts[target] = class_counts.get(target, 0) + 1
        if target == preds[idx]:
            correct_counts[target] = correct_counts.get(target, 0) + 1
    class_accuracy = {
        label: correct_counts.get(label, 0) / total
        for label, total in class_counts.items()
    }
    return class_counts, correct_counts, class_accuracy

def get_class_name(class_list, mapping):
    """Translate class indices back to class names.

    Args:
        class_list: iterable of sequences whose first element is a class index,
            e.g. ``(index, accuracy)`` pairs.
        mapping: ``{class_name: class_index}`` dict (such as ``ImageFolder.class_to_idx``).

    Returns:
        List of class names, one per entry of ``class_list``, in the same order.

    Raises:
        KeyError: if an index does not occur in ``mapping``. The previous version
            either raised ``UnboundLocalError`` or silently repeated the previous name.
    """
    # Invert once instead of scanning `mapping` for every entry. setdefault keeps the
    # first name for a duplicated index, matching the original first-match `break`.
    idx_to_name = {}
    for class_name, class_index in mapping.items():
        idx_to_name.setdefault(class_index, class_name)

    class_name_list = []
    for pair in class_list:
        class_label = pair[0]
        if class_label not in idx_to_name:
            raise KeyError(f'class index {class_label!r} not found in mapping')
        class_name_list.append(idx_to_name[class_label])
    return class_name_list

def save_list(filename, save_list):
    """Write the strings in ``save_list`` to ``filename``, separated by newlines.

    The file is overwritten, and no trailing newline follows the last item.
    """
    with open(filename, 'w') as out_file:
        out_file.write('\n'.join(save_list))
        
def select_csv(csv_df, class_name_list):
    """Return the rows of ``csv_df`` whose ``'WNID'`` appears in ``class_name_list``.

    Rows are grouped in the order of ``class_name_list``, keep their original order
    within each class, and get a fresh ``0..n-1`` index. A name listed twice yields
    its rows twice. If ``class_name_list`` is empty, the result is an empty frame with
    ``csv_df``'s columns.
    """
    # Gather the pieces and concatenate once. Concatenating inside the loop is quadratic,
    # and concatenating onto an empty column-only frame triggers a pandas FutureWarning.
    pieces = [csv_df[csv_df['WNID'] == class_name] for class_name in class_name_list]
    if not pieces:
        return pd.DataFrame(columns=csv_df.columns)
    return pd.concat(pieces).reset_index(drop=True)

def copy_figures(source_dir, target_dir, class_name_list):
    """Recreate ``target_dir`` holding a copy of each listed class folder of ``source_dir``.

    Any existing ``target_dir`` is deleted first. A class name listed more than once
    is copied again over its previous copy.
    """
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=False)
    for class_name in class_name_list:
        src_class_dir = os.path.join(source_dir, class_name)
        dst_class_dir = os.path.join(target_dir, class_name)
        # copytree refuses an existing destination, which happens for repeated names.
        if os.path.exists(dst_class_dir):
            shutil.rmtree(dst_class_dir)
        shutil.copytree(src_class_dir, dst_class_dir)

## Test on ImageNet with the full ANN

### Validate (all neurons)

In [24]:
# Baseline: a fresh, unmodified pretrained SqueezeNet evaluated on the whole validation set.
sub_output_dir = os.path.join(output_dir, 'results_all_neurons')
os.makedirs(sub_output_dir, exist_ok=True)
criterion = torch.nn.CrossEntropyLoss().to(device)
model = models.squeezenet1_1(pretrained=True).to(device)
model.eval()

SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): MaxPool2d

In [25]:
# Evaluate `model` on the validation set, logging running top-1/top-5 accuracy every
# `config.interval` batches and collecting all labels and top-1 predictions.
log_file_path = os.path.join(sub_output_dir, 'acc_output_log.txt')
if os.path.exists(log_file_path):
    os.remove(log_file_path)
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ":6.3f")
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
preds_list = []
targets_list = []
with torch.no_grad():
    end = time.time()
    for batch_id, (image, target) in enumerate(val_img_loader):
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), image.size(0))
        top1.update(acc1, image.size(0))
        top5.update(acc5, image.size(0))

        targets_list.append(target)
        _, pred = output.topk(1, 1, True, True)
        # squeeze(1), not squeeze(): a final batch of size 1 would otherwise become a
        # 0-d tensor, which torch.cat below cannot concatenate.
        pred = pred.squeeze(1)
        preds_list.append(pred)

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_id + 1) % config.interval == 0:
            with open(log_file_path, 'a') as f:
                print(f'Acc@1: {top1.avg.item():.3f} \t Acc@5: {top5.avg.item():.3f} \t Time: {batch_time.val:.2f} \t ID: {batch_id:d}', file=f)
    with open(log_file_path, 'a') as f:
        print(f' * Acc@1: {top1.avg.item():.3f} \t Acc@5: {top5.avg.item():.3f} ', file=f)
    targets = torch.cat(targets_list)
    preds = torch.cat(preds_list)


In [26]:
# Persist mean accuracies as JSON and the raw labels/predictions as tensors.
data = {'acc1': top1.avg.item(), 'acc5': top5.avg.item()}
with open(os.path.join(sub_output_dir, 'acc_average.json'), 'w') as f:
    json.dump(data, f)
for tensor_value, file_name in ((targets, 'targets.pt'), (preds, 'preds.pt')):
    torch.save(tensor_value, os.path.join(sub_output_dir, file_name))

### load results

In [27]:
# Reload the baseline results so this section can run without re-evaluating.
sub_output_dir = os.path.join(output_dir, 'results_all_neurons')
with open(os.path.join(sub_output_dir, 'acc_average.json'), 'r') as f:
    data = json.load(f)
targets, preds = (
    torch.load(os.path.join(sub_output_dir, file_name)).cpu().numpy()
    for file_name in ('targets.pt', 'preds.pt')
)

## Mask non-orientation-tuned neurons and test on ImageNet

### define mask layer and modify model

In [32]:
class MaskLayer(nn.Module):
    """Element-wise gate: multiplies its input by a fixed, non-trainable mask.

    The mask is registered as a buffer, so it moves with ``.to(device)`` and is part of
    ``state_dict`` but is not a learnable parameter. Units where the mask is 0/False
    are silenced; the mask must broadcast against the input.
    """
    def __init__(self, mask) -> None:
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, x):
        return torch.mul(x, self.mask)

In [48]:
target_layer_output = layer_outputs[most_resembles_layer]
# Derive the position inside model.features from the selected name ('features.<k>' -> k)
# instead of hard-coding 6, so the mask is always inserted after the layer it was built for.
most_resembles_layer_index = int(most_resembles_layer.split('.')[1])
ori_neurons_index = torch.from_numpy(loadmat_data(os.path.join(output_dir, 'ann_ori_neurons_index'))).to(device)
ori_neurons_index = ori_neurons_index.bool()
# True for orientation-tuned units -> multiplying by this mask zeroes the non-tuned ones.
ori_mask = ori_neurons_index.reshape(target_layer_output.shape)

In [49]:
# Fresh pretrained SqueezeNet with a MaskLayer(ori_mask) spliced in directly after
# features[most_resembles_layer_index].
model = models.squeezenet1_1(pretrained=True).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
features = model.features
new_features = []
for i, layer in enumerate(features.children()):
    new_features.append(layer)

    if i == most_resembles_layer_index:
        new_features.append(MaskLayer(ori_mask))
        print(i)
model.features = nn.Sequential(*new_features)
# Switch to eval mode only after the surgery: the newly created nn.Sequential and
# MaskLayer start in training mode, so calling eval() before the swap left them there.
model.eval();

6


In [51]:
# Display the rebuilt feature extractor to confirm where the MaskLayer was inserted.
model.features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (3): Fire(
    (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
    (squeeze_activation): ReLU(inplace=True)
    (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (expand1x1_activation): ReLU(inplace=True)
    (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (expand3x3_activation): ReLU(inplace=True)
  )
  (4): Fire(
    (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
    (squeeze_activation): ReLU(inplace=True)
    (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (expand1x1_activation): ReLU(inplace=True)
    (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (expand3x3_activation): ReLU(inplace=True)
  )
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
 

### validate

In [52]:
# Output folder for the run in which the non-orientation-tuned units are masked.
sub_output_dir = os.path.join(output_dir, 'results_mask_non_ori_neurons')
os.makedirs(sub_output_dir, exist_ok=True)

In [53]:
# Evaluate the masked `model` on the validation set, logging running top-1/top-5 accuracy
# every `config.interval` batches and collecting all labels and top-1 predictions.
log_file_path = os.path.join(sub_output_dir, 'acc_output_log.txt')
if os.path.exists(log_file_path):
    os.remove(log_file_path)
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ":6.3f")
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
preds_list = []
targets_list = []
with torch.no_grad():
    end = time.time()
    for batch_id, (image, target) in enumerate(val_img_loader):
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), image.size(0))
        top1.update(acc1, image.size(0))
        top5.update(acc5, image.size(0))

        targets_list.append(target)
        _, pred = output.topk(1, 1, True, True)
        # squeeze(1), not squeeze(): a final batch of size 1 would otherwise become a
        # 0-d tensor, which torch.cat below cannot concatenate.
        pred = pred.squeeze(1)
        preds_list.append(pred)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_id + 1) % config.interval == 0:
            with open(log_file_path, 'a') as f:
                print(f'Acc@1: {top1.avg.item():.3f} \t Acc@5: {top5.avg.item():.3f} \t Time: {batch_time.val:.2f} \t ID: {batch_id:d}', file=f)
    with open(log_file_path, 'a') as f:
        print(f' * Acc@1: {top1.avg.item():.3f} \t Acc@5: {top5.avg.item():.3f} ', file=f)
    targets = torch.cat(targets_list)
    preds = torch.cat(preds_list)


In [54]:
# Persist mean accuracies as JSON and the raw labels/predictions as tensors.
data = {'acc1': top1.avg.item(), 'acc5': top5.avg.item()}
with open(os.path.join(sub_output_dir, 'acc_average.json'), 'w') as f:
    json.dump(data, f)
for tensor_value, file_name in ((targets, 'targets.pt'), (preds, 'preds.pt')):
    torch.save(tensor_value, os.path.join(sub_output_dir, file_name))

### load results

In [None]:
# Reload the non-orientation-masked results so this section can run without re-evaluating.
sub_output_dir = os.path.join(output_dir, 'results_mask_non_ori_neurons')
with open(os.path.join(sub_output_dir, 'acc_average.json'), 'r') as f:
    data = json.load(f)
targets, preds = (
    torch.load(os.path.join(sub_output_dir, file_name)).cpu().numpy()
    for file_name in ('targets.pt', 'preds.pt')
)


## Mask orientation-tuned neurons (a random subset, as many as there are non-orientation-tuned neurons) and test on ImageNet

In [23]:
# Output folder for the control run that masks orientation-tuned units.
sub_output_dir = os.path.join(output_dir, 'results_mask_ori_neurons')
os.makedirs(sub_output_dir, exist_ok=True)

### define mask layer and modify model

In [24]:
class MaskLayer(nn.Module):
    """Element-wise gate: multiplies its input by a fixed, non-trainable mask.

    The mask is registered as a buffer, so it moves with ``.to(device)`` and is part of
    ``state_dict`` but is not a learnable parameter. Units where the mask is 0/False
    are silenced; the mask must broadcast against the input.
    """
    def __init__(self, mask) -> None:
        super().__init__()
        self.register_buffer('mask', mask)

    def forward(self, x):
        return torch.mul(x, self.mask)

In [25]:
target_layer_output = layer_outputs[most_resembles_layer]
# Derive the position inside model.features from the selected name ('features.<k>' -> k)
# instead of hard-coding 6, so the mask is always inserted after the layer it was built for.
most_resembles_layer_index = int(most_resembles_layer.split('.')[1])
ori_neurons_index = torch.from_numpy(loadmat_data(os.path.join(output_dir, 'ann_ori_neurons_index'))).to(device)
ori_neurons_index = ori_neurons_index.bool()

# Control ablation: silence a random subset of orientation-tuned units as large as the
# non-tuned population (or all tuned units if there are fewer of them).
# .cpu() is a no-op for CPU tensors, so no device branch is needed.
ori_neurons_index_array = ori_neurons_index.cpu().numpy()
ori_flat = ori_neurons_index_array.reshape(-1)  # savemat stored the 1-D index as a row
non_ori_indices = np.flatnonzero(~ori_flat)
neuron_number = ori_flat.size
Ori_number = int(ori_flat.sum())
Non_number = neuron_number - Ori_number
# Candidates are the tuned units in ascending order, so the seeded draw does not depend
# on Python set iteration order.
available_indices = np.flatnonzero(ori_flat)
np.random.seed(4)
random_indices = np.random.choice(available_indices, size=min(Non_number, Ori_number), replace=False)

mask_tmp = torch.ones(neuron_number, dtype=torch.bool)
mask_tmp[random_indices] = False
ori_mask = mask_tmp.reshape(target_layer_output.shape).to(device)

In [29]:
# Fresh pretrained SqueezeNet with a MaskLayer(ori_mask) spliced in directly after
# features[most_resembles_layer_index].
model = models.squeezenet1_1(pretrained=True).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
features = model.features
new_features = []
for i, layer in enumerate(features.children()):
    new_features.append(layer)

    if i == most_resembles_layer_index:
        new_features.append(MaskLayer(ori_mask))
        print(i)
model.features = nn.Sequential(*new_features)
# Switch to eval mode only after the surgery: the newly created nn.Sequential and
# MaskLayer start in training mode, so calling eval() before the swap left them there.
model.eval();

6


In [30]:
# Display the rebuilt feature extractor to confirm where the MaskLayer was inserted.
model.features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (3): Fire(
    (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
    (squeeze_activation): ReLU(inplace=True)
    (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (expand1x1_activation): ReLU(inplace=True)
    (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (expand3x3_activation): ReLU(inplace=True)
  )
  (4): Fire(
    (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
    (squeeze_activation): ReLU(inplace=True)
    (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (expand1x1_activation): ReLU(inplace=True)
    (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (expand3x3_activation): ReLU(inplace=True)
  )
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
 

### validate

In [31]:
# Evaluate the masked `model` on the validation set, logging running top-1/top-5 accuracy
# every `config.interval` batches and collecting all labels and top-1 predictions.
log_file_path = os.path.join(sub_output_dir, 'acc_output_log.txt')
if os.path.exists(log_file_path):
    os.remove(log_file_path)
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ":6.3f")
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
preds_list = []
targets_list = []
with torch.no_grad():
    end = time.time()
    for batch_id, (image, target) in enumerate(val_img_loader):
        image, target = image.to(device), target.to(device)
        # The mask is already built into `model` (MaskLayer inserted in model.features).
        # The former call `modified_model(image, mask=ori_mask)` referenced a name that
        # is never defined in this notebook.
        output = model(image)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), image.size(0))
        top1.update(acc1, image.size(0))
        top5.update(acc5, image.size(0))

        targets_list.append(target)
        _, pred = output.topk(1, 1, True, True)
        # squeeze(1), not squeeze(): a final batch of size 1 would otherwise become a
        # 0-d tensor, which torch.cat below cannot concatenate.
        pred = pred.squeeze(1)
        preds_list.append(pred)

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_id + 1) % config.interval == 0:
            with open(log_file_path, 'a') as f:
                print(f'Acc@1: {top1.avg.item():.3f} \t Acc@5: {top5.avg.item():.3f} \t Time: {batch_time.val:.2f} \t ID: {batch_id:d}', file=f)
    with open(log_file_path, 'a') as f:
        print(f' * Acc@1: {top1.avg.item():.3f} \t Acc@5: {top5.avg.item():.3f} ', file=f)
    targets = torch.cat(targets_list)
    preds = torch.cat(preds_list)


In [32]:
# Persist mean accuracies as JSON and the raw labels/predictions as tensors.
data = {'acc1': top1.avg.item(), 'acc5': top5.avg.item()}
with open(os.path.join(sub_output_dir, 'acc_average.json'), 'w') as f:
    json.dump(data, f)
for tensor_value, file_name in ((targets, 'targets.pt'), (preds, 'preds.pt')):
    torch.save(tensor_value, os.path.join(sub_output_dir, file_name))

In [33]:
# Show the mean top-1/top-5 accuracies that were just written to acc_average.json.
data

{'acc1': 57.02799987792969, 'acc5': 79.51599884033203}

### load results

In [None]:
# Reload the orientation-masked results so this section can run without re-evaluating.
sub_output_dir = os.path.join(output_dir, 'results_mask_ori_neurons')
with open(os.path.join(sub_output_dir, 'acc_average.json'), 'r') as f:
    data = json.load(f)
targets, preds = (
    torch.load(os.path.join(sub_output_dir, file_name)).cpu().numpy()
    for file_name in ('targets.pt', 'preds.pt')
)
