In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetricHead(nn.Module):
    """Metric head that scores every pixel against a set of learnable unit prototypes.

    A 1x1 convolution projects the input features to ``embed_dim`` channels, the
    per-pixel embeddings are L2-normalised, and the Euclidean distance from every
    pixel embedding to each L2-normalised learnable prototype is returned.

    Args:
        in_channels: number of channels of the input feature map.
        num_vectors: number of learnable prototype vectors (default 13).
        embed_dim: dimensionality of the embedding space (default 2).
    """

    def __init__(self, in_channels, num_vectors=13, embed_dim=2):
        super(MetricHead, self).__init__()

        # Channel reduction to the embedding space with a 1x1 convolution.
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

        # Learnable prototypes, one row per vector. They are normalised on every
        # forward pass, so only their direction is used.
        self.learnable_vectors = nn.Parameter(torch.randn(num_vectors, embed_dim))

    def forward(self, x):
        """Return per-pixel distances to every prototype.

        Args:
            x: feature map of shape (b, in_channels, h, w).

        Returns:
            Tensor of shape (b, num_vectors, h, w). Both operands are normalised,
            so every distance lies in [0, 2].
        """
        # Project to the embedding space and L2-normalise each pixel embedding.
        x = F.normalize(self.conv(x), p=2, dim=1)  # (b, d, h, w)

        # Normalise the prototypes inside forward so gradients flow through the
        # normalisation.
        vectors = F.normalize(self.learnable_vectors, p=2, dim=1)  # (k, d)

        # Broadcast (b, 1, d, h, w) against (1, k, d, 1, 1): one difference vector per
        # (sample, prototype, pixel). Aligning the batch axis explicitly keeps this
        # correct for any batch size (subtracting a (b, d) tensor from a (b, h*w, d)
        # one only broadcasts as intended when b == 1).
        diff = x.unsqueeze(1) - vectors[None, :, :, None, None]
        return torch.norm(diff, dim=2)  # (b, k, h, w)

# Example usage: one 3-channel 32x32 input -> 13 per-pixel distance maps.
torch.manual_seed(0)  # fixed seed so the example (weights and input) is reproducible
model = MetricHead(in_channels=3)
x = torch.randn(1, 3, 32, 32)
distances = model(x)
# Sanity check: one distance map per learnable vector, same spatial size as the input.
assert distances.shape == (1, 13, 32, 32)
print(distances.shape)

torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 1024, 2])
torch.Size([1, 13, 32, 32])


In [2]:
import os

In [74]:
def find_metrics(log_file):
    """Extract the evaluation metrics reported in a test log.

    Each line is split on whitespace and matched by marker token:

    * ``mIoU_no_zero:``       -> ``'mIoU'``   (value at token 4)
    * ``classe`` and ``3:``   -> ``'minIoU'`` (value at token 2)
    * ``maxIoU:``             -> ``'maxIoU'`` (value at token 4)
    * ``AUPR:``               -> key is token 3 lower-cased, value at token 5

    The three IoU markers are checked as an if/elif chain (first match wins on a
    line); the AUPR marker is checked independently. Every value is multiplied by
    100. If a marker appears on several lines, the last occurrence wins.

    Args:
        log_file: an open text file (anything providing ``readlines()``).

    Returns:
        dict mapping metric name to float.
    """
    found = {}
    for raw in log_file.readlines():
        if not raw:
            continue
        words = raw.split()

        # IoU metrics: at most one of these is recorded per line.
        if 'mIoU_no_zero:' in words:
            found['mIoU'] = float(words[4]) * 100
        elif 'classe' in words and '3:' in words:
            found['minIoU'] = float(words[2]) * 100
        elif 'maxIoU:' in words:
            found['maxIoU'] = float(words[4]) * 100

        # Anomaly AUPR, keyed by the lower-cased token 3 of the line.
        if 'AUPR:' in words:
            found[words[3].lower()] = float(words[5]) * 100
    return found

def find_parameters(log_file):
    """Read the run's hyper-parameters from the header of a training log.

    Every non-blank line before the first line containing the token ``Dataset``
    is read as a parameter: name at token 3, value at token 5 (both kept as
    strings). Parsing stops at the ``Dataset`` line.

    Args:
        log_file: an open text file (anything providing ``readlines()``).

    Returns:
        dict mapping parameter name -> raw string value.

    Raises:
        IndexError: if a non-blank header line has fewer than 6 tokens.
    """
    params = {}

    for line in log_file.readlines():
        tokens = line.split()
        # Lines from readlines() keep their trailing '\n', so a blank line is
        # detected by its tokens, not by len(line); otherwise tokens[3] raises.
        if not tokens:
            continue
        # The parameter header ends at the first line mentioning the dataset.
        if 'Dataset' in tokens:
            break
        params[tokens[3]] = tokens[5]
    return params

def prepare_metrics(metrics):
    """Format the seven reported metrics as one ';'-separated row.

    Values are rounded to 2 decimals and written with a decimal comma, in the
    order mIoU, minIoU, maxIoU, msp, maxlog, entropy, energy. There is no
    trailing separator.

    Raises:
        KeyError: if any of the seven metrics is missing.
    """
    column_order = ('mIoU', 'minIoU', 'maxIoU', 'msp', 'maxlog', 'entropy', 'energy')
    cells = [f"{round(metrics[name], 2)}" for name in column_order]
    row = ';'.join(cells)
    # Decimal comma instead of a decimal point.
    return row.replace('.', ',')

def prepare_parameters(parameters):
    """Format a run's switches and activation as ';'-terminated table cells.

    Cells, in order: biases, focal_loss, class_weights, norm_weights, mlp,
    activation. 'biases', 'class_weights' and 'norm_weights' must be the strings
    'True'/'False'. 'focal_loss' and 'mlp' are integer strings interpreted as
    booleans. Booleans are rendered as '✓' / '𐄂'. Every cell, including the last,
    is followed by ';' so the result can be concatenated directly with a metrics
    row.

    When 'mlp' is missing or not an integer string, its cell is '𐄂'.

    Raises:
        KeyError: if another expected key is missing, or a flag is not 'True'/'False'.
        ValueError: if 'focal_loss' is not an integer string.
    """
    mappa = {'True': '✓', 'False': '𐄂'}

    def int_flag(value):
        # '0' -> 'False', any other integer string -> 'True'.
        return str(bool(int(value)))

    stringa = ''
    stringa += f"{mappa[parameters['biases']]};"
    stringa += f"{mappa[int_flag(parameters['focal_loss'])]};"
    stringa += f"{mappa[parameters['class_weights']]};"
    stringa += f"{mappa[parameters['norm_weights']]};"
    # Only a missing or non-integer 'mlp' falls back to the "off" mark; other
    # errors (and KeyboardInterrupt) must not be swallowed.
    try:
        mlp_cell = mappa[int_flag(parameters['mlp'])]
    except (KeyError, ValueError):
        mlp_cell = mappa['False']
    stringa += f"{mlp_cell};"
    stringa += f"{parameters['activation']};"
    return stringa

In [86]:
# Build a LaTeX `array` summarising every run in `folder`: one row per run with its
# hyper-parameters and test metrics, sorted by run id, best value per metric underlined.
# NOTE(review): absolute, user-specific path; consider making it a configurable constant.
folder = '/raid/homespace/piecestola/space/ML4CV/results_6'

N_PARAM_COLS = 6   # Biases, Focal Loss, Class weights, Norm weights, MLP, Activation
N_METRIC_COLS = 7  # mIoU, minIoU, maxIoU, msp, maxlogit, entropy, energy

texts = list()
order = list()
for file in os.listdir(folder):
    path_train = os.path.join(folder, file, 'train.log')
    path_test = os.path.join(folder, file, 'test.log')
    try:
        # `with` closes both logs even when parsing raises.
        with open(path_train, 'r') as train_log:
            params = find_parameters(train_log)
        with open(path_test, 'r') as test_log:
            metrics = find_metrics(test_log)

        params_row = prepare_parameters(params)
        metrics_row = prepare_metrics(metrics)

        # Run id is the second '_'-separated field of the folder name.
        run_id = int(file.split('_')[1])
    except (OSError, KeyError, ValueError, IndexError) as err:
        # Missing logs, incomplete runs or unexpected folder names: report why, move on.
        print("Skipped", file, f"({type(err).__name__}: {err})")
        continue

    order.append(run_id)

    text = params_row + metrics_row
    text = text.replace(';', ' & ')
    text = text.replace('softmax', '\\text{softmax}')
    text = text.replace('sigmoid', '\\text{sigmoid}')
    text += ' \\\\'
    texts.append(text)
print()

texts = [v for _, v in sorted(zip(order, texts))]

# Column spec: 6 parameter | 3 IoU | 4 anomaly columns = 13, matching the header and rows.
print('\\begin{array}{|cccccc|ccc|cccc|}')
print('\\hline')
print('\\text{Biases} & \\text{Focal Loss} & \\text{Class weights} & \\text{Norm weights} & \\text{MLP} & \\text{Activation} & \\text{mIoU} & \\text{minIoU} & \\text{maxIoU} & \\text{msp} & \\text{maxlogit} & \\text{entropy} & \\text{energy} \\\\')
print('\\hline')

# After split(), cell c sits at token 2*c (odd tokens are '&'), so metric k is at
# token 2*(N_PARAM_COLS + k). Assumes no cell contains whitespace.
first_metric = 2 * N_PARAM_COLS
maxs = [float('-inf')] * N_METRIC_COLS
maxs_id = [None] * N_METRIC_COLS
for i, text in enumerate(texts):
    tokens = text.split()
    for k in range(N_METRIC_COLS):
        value = float(tokens[first_metric + 2 * k].replace(',', '.'))
        if value > maxs[k]:
            maxs[k] = value
            maxs_id[k] = i

for i, text in enumerate(texts):
    tokens = text.split()
    # `best_row` rather than `id`, so the builtin is not shadowed for later cells.
    for k, best_row in enumerate(maxs_id):
        if best_row == i:
            pos = first_metric + 2 * k
            tokens[pos] = '\\underline{' + tokens[pos] + '}'
    print(' '.join(tokens))

print('\\hline')
print('\\end{array}')


\begin{array}{|cccccc|ccc|cccccc|}
\hline
\text{Biases} & \text{Focal Loss} & \text{Class weights} & \text{Norm weights} & \text{MLP} & \text{Activation} & \text{mIoU} & \text{minIoU} & \text{maxIoU} & \text{msp} & \text{maxlogit} & \text{entropy} & \text{energy} \\
\hline
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & 32,93 & 1,8 & 72,74 & 4,22 & 6,34 & 7,33 & 6,23 \\
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & 64,09 & 9,01 & 97,24 & 11,61 & 15,06 & 14,04 & 15,15 \\
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & 27,97 & 3,08 & 71,76 & 3,41 & 4,73 & 5,33 & 4,72 \\
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & \underline{65,58} & 10,97 & \underline{97,6} & 10,88 & 14,66 & 13,5 & 14,78 \\
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & 64,42 & 10,48 & 96,86 & 12,1 & 16,15 & 15,61 & 16,37 \\
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & 61,11 & 8,97 & 96,23 & 10,17 & 11,66 & 12,04 & 11,22 \\
𐄂 & 𐄂 & 𐄂 & 𐄂 & 𐄂 & \text{softmax} & 63,33 & 10,24 & 96,69 & \underline{12,5} & \underline{18,16} & \underline{16,01} & \underline{18,81} \\
𐄂 & 𐄂 & 𐄂 & 