# TLP experimental framework

In [1]:
# Root directory for experiment outputs: the weight-snapshot folders and the
# JSONL experiment log are all created under it.
root_dir: str = "./TLP"

In [None]:
# Regularization.
# Decreased batch size; TLP pooling on layer 2 only.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import classification_report
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
import os
from datetime import datetime
import json
import time
import uuid
from copy import deepcopy

# stage
use_restored_weights = False
inverse = False
save_dir = f"{root_dir}/raw_weights_stl10_3_7"
restore_dir = f"{root_dir}/raw_weights_stl10_3_1"
method_description = (
    "STL10, 2 layer TLP. Regularisation. 'dirbias': True,'entropy': False,'var': True,'inverse': True,'lambda_dirbias': 0.1,'lambda_entropy': 0.01,'lambda_var': 0.1"
)
if use_restored_weights:
    method_description = "STL10, 2 layer TLP. TLP frozen, retrain conv2"

device_type = "cuda" if torch.cuda.is_available() else "cpu"
print(device_type)
device = torch.device(device_type)

# Load STL-10 dataset.
# Random flip/crop augmentation belongs to the training split only. The test
# split gets a deterministic transform so evaluation metrics are not perturbed
# by random augmentation and are comparable between runs.
transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomCrop(96, padding=4), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

if os.path.exists("data_cache/trainset.pkl") and os.path.exists("data_cache/testset.pkl"):
    print("Load train data from cache...")
    # NOTE(security): pickle.load can execute code stored in the file; this
    # cache is only meant to hold datasets written by this script itself.
    with open("data_cache/trainset.pkl", "rb") as f:
        trainset = pickle.load(f)
    with open("data_cache/testset.pkl", "rb") as f:
        testset = pickle.load(f)
    # Pickled datasets carry whatever transform they were created with (older
    # caches stored the augmenting transform on the test split as well), so
    # always re-attach the current transforms after loading.
    trainset.transform = transform
    testset.transform = test_transform
else:
    print("Load and create datasets...")
    trainset = datasets.STL10(root="./data", split="train", download=True, transform=transform)
    testset = datasets.STL10(root="./data", split="test", download=True, transform=test_transform)

    # Cache the dataset objects for subsequent runs.
    os.makedirs("data_cache", exist_ok=True)
    with open("data_cache/trainset.pkl", "wb") as f:
        pickle.dump(trainset, f)
    with open("data_cache/testset.pkl", "wb") as f:
        pickle.dump(testset, f)

trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32, shuffle=False)


def save_raw_weights(model, epoch, epochs, save_dir="raw_weights_stl10_2_2"):
    """Write a per-epoch snapshot of the pool2/conv2 weights of *model*.

    A pickle named ``epoch_XX.pkl`` (1-based epoch, zero-padded to two digits)
    is written into *save_dir* (created if needed). It holds a dict with:

    * ``"pool2_weights"`` -- only when ``model.pool2`` is a TropicalMaxPool2d;
    * ``"conv2_kernels"`` -- only when the model has a ``conv2`` attribute;

    both as NumPy arrays. After the final epoch (``epoch + 1 == epochs``) the
    state dict and the whole model are additionally saved via ``torch.save``.
    """
    os.makedirs(save_dir, exist_ok=True)

    def _as_numpy(tensor):
        # Detach from autograd and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    snapshot = {}
    if isinstance(model.pool2, TropicalMaxPool2d):
        snapshot["pool2_weights"] = _as_numpy(model.pool2.weights)
    if hasattr(model, "conv2"):
        snapshot["conv2_kernels"] = _as_numpy(model.conv2.weight)

    pickle_path = os.path.join(save_dir, f"epoch_{epoch + 1:02d}.pkl")
    with open(pickle_path, "wb") as handle:
        pickle.dump(snapshot, handle)

    if epoch + 1 != epochs:
        return

    # Final epoch: persist both the state dict and the full pickled module.
    torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch + 1:02d}_state_dict.pt"))
    torch.save(model, os.path.join(save_dir, f"epoch_{epoch + 1:02d}_full_model.pt"))


def log_experiment_results(uid, method_description, report_dict, runtime_seconds, test_loss, accuracy, save_dir, reg_opts=None, log_file=None):
    """Append one experiment summary as a JSON line to the TLP log.

    Args:
        uid: Unique run identifier.
        method_description: Free-text label of the run.
        report_dict: ``classification_report(..., output_dict=True)``-style
            dict; ``"accuracy"``, ``"macro avg"`` and ``"weighted avg"`` are
            read from it when present.
        runtime_seconds: Wall-clock runtime of the run.
        test_loss: Average test loss.
        accuracy: Test accuracy; used when *report_dict* has no
            ``"accuracy"`` entry (previously this argument was ignored).
        save_dir: Directory holding the run's weight snapshots.
        reg_opts: Regularisation options of the run (stored as ``{}`` if None).
        log_file: Path of the JSONL log; defaults to
            ``{root_dir}/log_TLP.jsonl``.
    """
    if log_file is None:
        log_file = f"{root_dir}/log_TLP.jsonl"
    record = {
        "uid": uid,
        "date": datetime.now().isoformat(),
        "method": method_description,
        "accuracy": report_dict.get("accuracy", accuracy),
        "macro_f1": report_dict.get("macro avg", {}).get("f1-score", None),
        "weighted_f1": report_dict.get("weighted avg", {}).get("f1-score", None),
        "runtime_sec": runtime_seconds,
        "test_loss": test_loss,
        "save_dir": save_dir,
        "reg_opts": reg_opts or {},
    }
    # os.makedirs("") raises, so only create a directory when the path has one.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


def load_tlp_weights(model, weights_path):
    """Copy pickled TLP pooling weights into *model* and freeze them.

    *weights_path* is a pickle holding a dict with a ``"pool2_weights"`` array,
    as written by ``save_raw_weights``. The weights go into
    ``model.pool2_tlp`` when the model defines that attribute, otherwise into
    ``model.pool2`` (the attribute ConvNetSTL10 actually has; the original
    code only looked for ``pool2_tlp`` and so always failed on it).

    NOTE(security): ``pickle.load`` can execute code from the file -- only
    load snapshots produced by this script.
    """
    with open(weights_path, "rb") as f:
        weights = pickle.load(f)
    pool = getattr(model, "pool2_tlp", None)
    if pool is None:
        pool = model.pool2
    target = pool.weights
    tlp_weights = torch.as_tensor(weights["pool2_weights"], dtype=target.dtype, device=target.device)
    with torch.no_grad():
        target.copy_(tlp_weights)
    target.requires_grad = False  # freeze the loaded TLP weights


def display_jsonl_as_table(sort_by="date", descending=True, log_file=None):
    """Load the JSONL experiment log into a DataFrame, one row per record.

    Args:
        sort_by: Column to sort by; sorting is skipped if the column is absent.
        descending: Sort order (newest first for the default ``"date"``).
        log_file: Path of the JSONL log; defaults to
            ``{root_dir}/log_TLP.jsonl``.

    Returns:
        pandas.DataFrame of the records. Blank lines are skipped instead of
        making ``json.loads`` raise.
    """
    if log_file is None:
        log_file = f"{root_dir}/log_TLP.jsonl"
    with open(log_file, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    df = pd.DataFrame(records)
    if sort_by in df.columns:
        df = df.sort_values(by=sort_by, ascending=not descending)
    return df


# Learnable Tropical Pooling
class TropicalMaxPool2d(nn.Module):
    """Learnable max-plus ("tropical") pooling.

    For every channel ``c`` and every ``kernel_size x kernel_size`` window the
    output is ``max_k (x_k + w[c, k])``: a max pool in which each window
    position carries a learnable additive offset. With all weights equal to
    zero it matches ``nn.MaxPool2d`` with the same kernel/stride/padding.

    Args:
        channels: Number of input channels (one weight row per channel).
        kernel_size: Side length of the square pooling window.
        stride: Window stride.
        padding: Implicit padding on each side. Padded cells hold ``-inf`` so
            they can never win the max (same convention as ``nn.MaxPool2d``).
    """

    def __init__(self, channels, kernel_size=2, stride=2, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # One offset per (channel, window position), small random init.
        self.weights = nn.Parameter(torch.randn(channels, kernel_size * kernel_size) * 0.01)

    def forward(self, x):
        B, C, H, W = x.size()
        k = self.kernel_size
        if self.padding:
            # Pad with -inf, not 0: zero padding would beat negative
            # activations/offsets and leak into the pooled output.
            x = nn.functional.pad(x, (self.padding,) * 4, value=float("-inf"))
        unfolded = nn.functional.unfold(x, kernel_size=k, stride=self.stride)  # [B, C*k*k, L]
        unfolded = unfolded.view(B, C, k * k, -1)
        weighted = unfolded + self.weights.view(1, C, -1, 1)
        pooled = weighted.max(dim=2)[0]
        out_h = (H + 2 * self.padding - k) // self.stride + 1
        out_w = (W + 2 * self.padding - k) // self.stride + 1
        return pooled.view(B, C, out_h, out_w)


# MaxPool wrapper
class StandardMaxPool2d(nn.Module):
    """Plain ``nn.MaxPool2d`` behind the same constructor as TropicalMaxPool2d.

    ``channels`` is accepted only so both pooling classes can be built as
    ``pool_cls(channels)``; it has no effect here.
    """

    def __init__(self, channels, kernel_size=2, stride=2, padding=0):
        super().__init__()
        del channels  # unused; kept for interface parity with TropicalMaxPool2d
        self.pool = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x):
        max_pool = self.pool
        return max_pool(x)


# Shared architecture
class ConvNetSTL10(nn.Module):
    """Three-stage CNN for 96x96 STL-10 images with a pluggable second pool.

    Each stage is conv3x3 (padding 1) -> ReLU -> pool. Stages 1 and 3 use a
    fixed 2x2/stride-2 max pool; stage 2 uses ``pool_cls(64)``. The classifier
    head assumes every pool halves the spatial size, so a 96x96 input reaches
    ``fc1`` as 128 channels of 12x12.

    Args:
        pool_cls: Callable taking the channel count (64) and returning the
            pooling module used after ``conv2``.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
    """

    def __init__(self, pool_cls, in_channels=3, num_classes=10):
        super().__init__()
        # Registration order is kept fixed: it determines both the state_dict
        # key order and the order of random draws during parameter init.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = pool_cls(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        side = 96 // 2 // 2 // 2  # three halving pools: 96 -> 48 -> 24 -> 12
        self.fc1 = nn.Linear(128 * side * side, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(torch.relu(conv(x)))
        hidden = torch.relu(self.fc1(torch.flatten(x, 1)))
        return self.fc2(hidden)


# === TLP regularization ===
def compute_alignment_regularization(
    conv_feats, tlp_weights, reg_dirbias=True, reg_entropy=True, reg_var=True, inverse=False, lambda_dirbias=0.1, lambda_entropy=0.05, lambda_var=0.05
):
    """Regularisation term tying TLP pooling weights to conv feature directions.

    ``dot`` is the per-sample cosine-style similarity between the
    spatially-averaged, L2-normalised conv descriptor (length C) and the
    L2-normalised channel-wise mean of the row-normalised TLP weights.

    Args:
        conv_feats: Feature map of shape ``(B, C, H, W)``.
        tlp_weights: Pooling weights of shape ``(C, K)`` (K = window size).
        reg_dirbias: Add ``lambda_dirbias * |mean(dot)|``.
        reg_entropy: Add ``lambda_entropy * mean row entropy of softmax(tlp_weights)``.
        reg_var: Add ``lambda_var * var(dot)``. The unbiased variance needs at
            least two samples; for ``B < 2`` this term is 0 instead of NaN.
        inverse: Subtract the enabled terms instead of adding them (reward
            rather than penalise). Previously a ``clamp(min=0)`` turned every
            inverted term into 0, so this flag had no effect.
        lambda_dirbias, lambda_entropy, lambda_var: Term weights.

    Returns:
        A 0-dim tensor on ``tlp_weights``' device. It is ``0.`` when no term
        is enabled, so callers can always use ``.item()`` on the result.
    """
    sign = -1.0 if inverse else 1.0
    B, C, _, _ = conv_feats.shape
    conv_flat = conv_feats.reshape(B, C, -1).mean(-1)  # [B, C]
    conv_norm = conv_flat / (conv_flat.norm(dim=1, keepdim=True) + 1e-6)  # [B, C]

    tlp_norm = tlp_weights / (tlp_weights.norm(dim=1, keepdim=True) + 1e-6)  # [C, K]
    tlp_agg = tlp_norm.mean(dim=1)  # [C]
    tlp_agg = tlp_agg / (tlp_agg.norm() + 1e-6)  # [C]

    dot = (conv_norm * tlp_agg.unsqueeze(0)).sum(1)  # [B]

    reg_loss = tlp_weights.new_zeros(())

    if reg_dirbias:
        reg_loss = reg_loss + sign * lambda_dirbias * dot.mean().abs()

    if reg_var:
        var = dot.var() if B > 1 else dot.new_zeros(())
        reg_loss = reg_loss + sign * lambda_var * var

    if reg_entropy:
        p = torch.softmax(tlp_weights, dim=1)
        entropy = -(p * torch.log(p + 1e-6)).sum(dim=1).mean()
        reg_loss = reg_loss + sign * lambda_entropy * entropy

    return reg_loss


# === Helper to extract conv2 features ===
def extract_conv2_feats(model, x):
    """Return the detached post-ReLU ``conv2`` activations of *model* for *x*.

    Replays the start of the forward pass (conv1 -> ReLU -> pool1 -> conv2 ->
    ReLU) and stops before ``pool2``. The result is detached, so no gradient
    can flow back into conv1/conv2 through it.
    """
    stage1 = model.pool1(torch.relu(model.conv1(x)))
    return torch.relu(model.conv2(stage1)).detach()


# Training and evaluation
def train_model(model, loader, criterion, optimizer, epochs=5, save_weights=False, reg_opts=None):
    """Train *model* for *epochs* epochs, optionally adding the TLP regulariser.

    Args:
        model: Network with a ``pool2`` submodule. The regulariser is applied
            only when ``model.pool2`` has a ``weights`` attribute.
        loader: Iterable of ``(inputs, labels)`` batches with a ``len()``.
        criterion: Base loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimiser over the model parameters.
        epochs: Number of passes over *loader*.
        save_weights: If true, call ``save_raw_weights`` after every epoch,
            writing into the module-level ``save_dir``.
        reg_opts: Dict of flags/lambdas forwarded to
            ``compute_alignment_regularization``; ``None`` disables it.

    Batches are moved to the module-level ``device``. After each epoch the
    mean base/regularisation/total loss and the training accuracy are printed.
    """
    model.train()
    n_batches = max(len(loader), 1)  # guard the per-epoch averages against an empty loader
    for epoch in range(epochs):
        total, correct, loss_sum, base_sum, reg_sum = 0, 0, 0.0, 0.0, 0.0
        for inputs, labels in tqdm(loader, desc=f"Epoch {epoch+1}"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            base_loss = criterion(outputs, labels)
            loss = base_loss

            if reg_opts is not None and hasattr(model.pool2, "weights"):
                # Features are recomputed without autograd and are detached,
                # so the regulariser only sends gradient to the TLP weights.
                with torch.no_grad():
                    conv_feats = extract_conv2_feats(model, inputs)
                reg_loss = compute_alignment_regularization(
                    conv_feats,
                    model.pool2.weights,
                    reg_dirbias=reg_opts.get("dirbias", False),
                    reg_entropy=reg_opts.get("entropy", False),
                    reg_var=reg_opts.get("var", False),
                    inverse=reg_opts.get("inverse", False),
                    lambda_dirbias=reg_opts.get("lambda_dirbias", 0.1),
                    lambda_entropy=reg_opts.get("lambda_entropy", 0.05),
                    lambda_var=reg_opts.get("lambda_var", 0.05),
                )
                loss = base_loss + reg_loss
                # float() instead of .item(): the regulariser may return a
                # plain Python 0.0 when every term is disabled.
                reg_sum += float(reg_loss)

            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            base_sum += base_loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = 100 * correct / total if total else 0.0
        # Report epochs 1-based so the summary matches the tqdm bar label.
        print(
            f"Epoch {epoch + 1}: Base Loss = {base_sum / n_batches:.4f}, Reg Loss = {reg_sum / n_batches:.4f}, Total Loss = {loss_sum / n_batches:.4f} | Accuracy: {train_acc:.2f}%"
        )

        if save_weights:
            save_raw_weights(model, epoch, epochs, save_dir=save_dir)


def evaluate_model(model, loader, criterion):
    """Evaluate *model* on *loader* with gradients disabled.

    Batches are moved to the module-level ``device``. Returns
    ``(report, avg_loss, accuracy)``:

    * ``report`` -- ``classification_report(..., output_dict=True)`` with its
      ``"accuracy"`` entry overwritten by ``accuracy``;
    * ``avg_loss`` -- criterion value averaged over batches;
    * ``accuracy`` -- fraction of correct predictions (0..1).

    A one-line summary is printed as well.
    """
    model.eval()
    loss_total = 0.0
    predicted, expected = [], []

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            loss_total += criterion(logits, batch_y).item()
            # Index 1 of max(...) holds the arg-max class per sample.
            predicted.extend(logits.max(1)[1].cpu().numpy())
            expected.extend(batch_y.cpu().numpy())

    avg_loss = loss_total / len(loader)
    accuracy = np.mean(np.array(predicted) == np.array(expected))
    report = classification_report(expected, predicted, output_dict=True)
    report["accuracy"] = accuracy

    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {100 * accuracy:.2f}%")
    return report, avg_loss, accuracy


import itertools
import uuid
import json
import time
from datetime import datetime
import os

# Example base values swept for each lambda
lambda_values = [0.0, 0.01, 0.05, 0.1]
boolean_flags = [(True, False)]

# Parameter grid: three on/off switches (dirbias, entropy, var) followed by
# three lambda sweeps (lambda_dirbias, lambda_entropy, lambda_var).
parameter_grid = list(itertools.product(*([boolean_flags[0]] * 3 + [lambda_values] * 3)))

# Generate one reg_opts dict per grid point; "inverse" stays off in this sweep.
experiment_configs = []
for dirbias, entropy, var, l_dirbias, l_entropy, l_var in parameter_grid:
    config = dict(
        dirbias=dirbias,
        entropy=entropy,
        var=var,
        inverse=False,
        lambda_dirbias=l_dirbias,
        lambda_entropy=l_entropy,
        lambda_var=l_var,
    )
    experiment_configs += [config]


# Logging: log_experiment_results is defined once, earlier in this cell (next
# to save_raw_weights). It is intentionally not redefined here, because a
# second copy would silently shadow the first and any change made to it.


start_from = 0  # Resume point: experiments with index < start_from are skipped (index i is also the suffix of its save_dir folder)
for i, config in enumerate(experiment_configs):
    if i < start_from:
        continue
    # Rebinds the module-level save_dir; train_model reads it (it is not
    # passed as an argument) when save_weights=True.
    save_dir = f"{root_dir}/raw_weights_stl10_3_8_{i}"
    print(f"\n=== Experiment {i+1}/{len(experiment_configs)} ===")
    uid = str(uuid.uuid4())
    method_description = f"ablation_TLP_{i:03d}"

    # Initialize a fresh model and optimizer for this experiment
    model = ConvNetSTL10(TropicalMaxPool2d).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Run training (40 epochs), then evaluate on the test loader. A deep copy
    # of config is handed to train_model so the config logged below is not
    # affected by anything done to reg_opts during training.
    start_time = time.time()
    train_model(model, trainloader, criterion, optimizer, epochs=40, save_weights=True, reg_opts=deepcopy(config))
    report_dict, test_loss, accuracy = evaluate_model(model, testloader, criterion)

    # Logging; the recorded runtime covers training plus evaluation.
    log_experiment_results(
        uid=uid,
        method_description=method_description,
        report_dict=report_dict,
        runtime_seconds=round(time.time() - start_time, 2),
        test_loss=test_loss,
        accuracy=accuracy,
        save_dir=save_dir,
        reg_opts=config,
    )


# Load the accumulated log as a DataFrame (sorted by date, newest first).
display_jsonl_as_table()

cpu
Load and create datasets...

=== Experiment 1/512 ===


Epoch 1: 100%|██████████| 157/157 [00:39<00:00,  3.97it/s]


Epoch 0: Base Loss = 1.8058, Reg Loss = 0.0000, Total Loss = 1.8058 | Accuracy: 32.86%


Epoch 2: 100%|██████████| 157/157 [00:39<00:00,  3.99it/s]


Epoch 1: Base Loss = 1.4229, Reg Loss = 0.0000, Total Loss = 1.4229 | Accuracy: 47.62%


Epoch 3: 100%|██████████| 157/157 [00:43<00:00,  3.61it/s]


Epoch 2: Base Loss = 1.2833, Reg Loss = 0.0000, Total Loss = 1.2833 | Accuracy: 52.30%


Epoch 4: 100%|██████████| 157/157 [00:39<00:00,  4.01it/s]


Epoch 3: Base Loss = 1.1476, Reg Loss = 0.0000, Total Loss = 1.1476 | Accuracy: 57.72%


Epoch 5:  39%|███▉      | 62/157 [00:15<00:31,  3.05it/s]