In [2]:
# !wget https://github.com/marcin119a/data/raw/refs/heads/main/data_gsn.zip
# !unzip data_gsn.zip &> /dev/null
# !rm data_gsn.zip

In [None]:
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
from PIL import Image
import numpy as np
import os
from torchvision import transforms
import torch.nn.functional as F
from sklearn.metrics import f1_score, confusion_matrix
import pickle

torch.manual_seed(1)

<torch._C.Generator at 0x7f1928071770>

In [4]:
all_pairs = [(i, j) for i in range(6) for j in range(i + 1, 6)]
pair_to_idx = {p: k for k, p in enumerate(all_pairs)}

def class_id_to_pair_and_split(class_id: int):
    pair_idx = class_id // 9
    split_idx = class_id % 9
    ca = split_idx + 1
    cb = 10 - ca
    i, j = all_pairs[pair_idx]
    return (i, j), (ca, cb)

def class_id_to_pair(class_id: int):
    pair_idx = class_id // 9
    return all_pairs[pair_idx]

def counts_to_class_id(counts):
    if isinstance(counts, torch.Tensor):
        c = counts.detach().cpu().tolist()
    else:
        c = list(counts)

    nz = [i for i, v in enumerate(c) if v > 0]
  
    a, b = sorted(nz)
    ca = int(c[a])
    pair_index = pair_to_idx[(a, b)]

    class_id = pair_index * 9 + (ca - 1)
    return class_id

In [5]:
class GSN(Dataset):
    def __init__(self, root, transform=None, transform_relabel=None):
        self.data_dir = os.path.join(root, "data")
        self.transform = transform
        self.transform_relabel = transform_relabel

        df = pd.read_csv(os.path.join(self.data_dir, "labels.csv"))
        self.names = df["name"].tolist()
        cols = ["squares", "circles", "up", "right", "down", "left"]
        self.labels = torch.tensor(df[cols].values, dtype=torch.float32)

    def __len__(self):
        return len(self.names)
    
    def __getitem__(self, index):
        name = self.names[index]
        img_path = os.path.join(self.data_dir, name)

        img = Image.open(img_path).convert("L")
        img = transforms.ToTensor()(img)
        
        if self.transform:
            img = self.transform(img)

        cnt = self.labels[index]

        if self.transform_relabel:
            img, cnt = self.transform_relabel(img, cnt)
        
        cls = counts_to_class_id(cnt)

        return img, cls, cnt

In [6]:
class Augmentation:
    def __init__(self, p_hflip=0.5, p_vflip=0.5):
        self.p_hflip = p_hflip
        self.p_vflip = p_vflip

    def __call__(self, img, cnt):
        cnt = cnt.clone()

        k = torch.randint(0, 4, (1,)).item()
        if k > 0:
            img = torch.rot90(img, k=-k, dims=[1,2])
            dirs = cnt[2:6]
            dirs = torch.roll(dirs, shifts=k)
            cnt[2:6] = dirs

        if torch.rand(1).item() < self.p_hflip:
            img = torch.flip(img, dims=[2])
            cnt[[3, 5]] = cnt[[5, 3]]
        
        if torch.rand(1).item() < self.p_vflip:
            img = torch.flip(img, dims=[1])
            cnt[[2, 4]] = cnt[[4, 2]]

        return img, cnt

In [7]:
class NeuralNetwork(nn.Module):
    def __init__(self, cls_hidden=256, cnt_hidden=256, dropout=0.3):
        super().__init__()

        self.backbone = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=1, padding=1), nn.ReLU(),     
            nn.Conv2d(8, 16, 3, stride=1, padding=1), nn.ReLU(),   
            nn.Conv2d(16, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),
            nn.Flatten(start_dim=1),
            nn.Linear(64 * 28 * 28, 256), nn.ReLU()
        )

        self.head_cls = nn.Sequential(
            nn.Linear(256, cls_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(cls_hidden, 135),
            nn.LogSoftmax(dim=1)
        )

        self.head_cnt = nn.Sequential(
            nn.Linear(256, cnt_hidden),
            nn.ReLU(),
            nn.Linear(cnt_hidden, 6)
        )
    
    def forward(self, x):
        x = self.backbone(x)

        cls = self.head_cls(x)
        cnt = self.head_cnt(x)

        return cls, cnt

In [8]:
def train_epoch(
    net: torch.nn.Module,
    device: torch.device,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    mode: str,
    lambda_cnt = None,
):
    net.train()
    total_loss = 0.0
    n_total = 0
    
    for img, cls_target, cnt_target in train_loader:
        img, cls_target, cnt_target = img.to(device), cls_target.long().to(device), cnt_target.to(device)

        optimizer.zero_grad()

        cls_pred, cnt_pred = net(img)
        
        cls_loss = F.nll_loss(cls_pred, cls_target)
        cnt_loss = F.smooth_l1_loss(cnt_pred, cnt_target)

        if mode == "cls_only":
            loss = cls_loss
        elif mode == "reg_only":
            loss = cnt_loss
        else:
            loss = cls_loss + lambda_cnt * cnt_loss

        loss.backward()
        optimizer.step()

        B = len(img)

        total_loss += loss.item() * B
        n_total += B

    epoch_loss = total_loss / n_total
    return epoch_loss


In [9]:
def eval_epoch(
    net: torch.nn.Module,
    device: torch.device,
    test_loader: torch.utils.data.DataLoader,
    epoch: int,
    mode: str,
    lambda_cnt = None,
    verbose: bool = False,
):
    net.eval()
    total_loss = total_cls_loss = total_cnt_loss = 0.0
    n_total = 0
    correct = 0

    sum_sq_diff = 0.0
    n_cnt = 0

    with torch.no_grad():
        for img, cls_target, cnt_target in test_loader:
            img, cls_target, cnt_target = img.to(device), cls_target.long().to(device), cnt_target.to(device)

            cls_pred, cnt_pred = net(img)

            cls_loss = F.nll_loss(cls_pred, cls_target)
            cnt_loss = F.smooth_l1_loss(cnt_pred, cnt_target)

            if mode == "cls_only":
                loss = cls_loss
            elif mode == "reg_only":
                loss = cnt_loss
            else:
                loss = cls_loss + lambda_cnt * cnt_loss

            total_loss += loss.item() * len(img)
            total_cls_loss += cls_loss.item() * len(img)
            total_cnt_loss += cnt_loss.item() * len(img)
            n_total += len(img)

            pred = cls_pred.argmax(dim=1)
            correct += (pred == cls_target).sum().item()

            diff = cnt_pred - cnt_target
            sum_sq_diff += (diff ** 2).sum().item()
            n_cnt += diff.numel()

    epoch_loss = total_loss / n_total
    epoch_cls_loss = total_cls_loss / n_total
    epoch_cnt_loss = total_cnt_loss / n_total
    epoch_acc = correct / n_total

    epoch_rmse = (sum_sq_diff / n_cnt) ** 0.5

    if verbose:
        print(
            f"Eval Epoch: {epoch} | "
            f"acc: {epoch_acc:.4f} | "
            f"loss: {epoch_loss:.4f} | "
            f"cls_loss: {epoch_cls_loss:.4f} | "
            f"cnt_loss: {epoch_cnt_loss:.4f}"
        )
    return epoch_loss, epoch_acc, epoch_rmse

In [10]:
def evaluate_metrics(net, device, loader):
    net.eval()

    all_cls_true = []
    all_cls_pred = []
    all_cnt_true = []
    all_cnt_pred = []

    with torch.no_grad():
        for img, cls_target, cnt_target in loader:
            img = img.to(device)

            cls_logits, cnt_pred = net(img)
            cls_pred = cls_logits.argmax(dim=1)

            all_cls_true.append(cls_target)
            all_cls_pred.append(cls_pred.to("cpu"))            
            all_cnt_true.append(cnt_target)
            all_cnt_pred.append(cnt_pred.to("cpu"))

    cls_true = torch.cat(all_cls_true)
    cls_pred = torch.cat(all_cls_pred)
    cnt_true = torch.cat(all_cnt_true)
    cnt_pred = torch.cat(all_cnt_pred)

    acc = (cls_true == cls_pred).float().mean().item()
    print(f"Top-1 accuracy: {acc}" )  

    cls_true_np = cls_true.numpy() 
    cls_pred_np = cls_pred.numpy() 
    macro_f1 = f1_score(cls_true_np, cls_pred_np, average="macro")
    print(f"Macro F1:{macro_f1}")

    correct_pair = 0
    total = len(cls_true)
    for t, p in zip(cls_true.tolist(), cls_pred.tolist()):
        if class_id_to_pair(int(t)) == class_id_to_pair(int(p)):
            correct_pair += 1
    pair_acc = correct_pair / total
    print(f"Per-pair accuracy: {pair_acc}")

    diff = cnt_pred - cnt_true
    mse_per_class = (diff ** 2).mean(dim=0)
    rmse_per_class = torch.sqrt(mse_per_class)
    mae_per_class = diff.abs().mean(dim=0)
    overall_rmse = torch.sqrt((diff ** 2).mean()).item()
    overall_mae = diff.abs().mean().item()  
    print(f"RMSE per class: {rmse_per_class.tolist()}")
    print(f"MAE per class: {mae_per_class.tolist()}")
    print(f"Overall RMSE: {overall_rmse}")
    print(f"Overall MAE: {overall_mae}")

    cm = confusion_matrix(cls_true_np, cls_pred_np, labels=np.arange(135))

    return {
        "acc": acc,
        "macro_f1": macro_f1,
        "pair_acc": pair_acc,
        "rmse_per_class": rmse_per_class.tolist(),
        "mae_per_class": mae_per_class.tolist(),
        "overall_rmse": overall_rmse,
        "overall_mae": overall_mae,
        "confusion_matrix": cm,
    }

In [11]:
def create_loaders(root=".", batch_size=64, test_batch_size=1000, device=torch.device("cpu")):
    if device.type == "cuda":
        num_workers = min(8, os.cpu_count() or 2)
    else:
        num_workers = 0

    pin = (device.type == "cuda")
    loader_kwargs = dict(
        num_workers=num_workers,
        pin_memory=pin,
        persistent_workers=False,
    )
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 4

    train_aug = Augmentation()
    train_full = GSN(root=root, transform_relabel=train_aug)
    test_full = GSN(root=root)

    train_dataset = Subset(train_full, range(0, 9000))
    test_dataset = Subset(test_full, range(9000, 10000))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


In [12]:
epochs = 100
lr = 1e-3
batch_size = 64
test_batch_size = 1000

cls_hidden = 256
cnt_hidden = 256
dropout = 0.3
patience = 10
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu")

def train_model(
    mode: str,
    lambda_cnt = None,
):
    if device.type == "cuda":
        torch.backends.cudnn.conv.fp32_precision = 'tf32'


    train_loader, test_loader = create_loaders(
        root=".",
        batch_size=64,
        test_batch_size=1000,
        device=device,
    )

    net = NeuralNetwork(cls_hidden, cnt_hidden, dropout).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    history = {
        "train_loss": [],
        "eval_loss": [],
        "eval_acc": [],
        "eval_rmse": [],
    }

    best_eval_loss = float("inf")
    best_eval_loss_acc = 0.0
    best_state = None
    bad_epochs = 0
    best_epoch = 0

    for epoch in range(1, epochs+1):
        train_loss = train_epoch(
            net,
            device,
            train_loader,
            optimizer,
            mode,
            lambda_cnt=lambda_cnt,
        )

        eval_loss, eval_acc, eval_rmse = eval_epoch(
            net,
            device,
            test_loader,
            epoch,
            mode,
            lambda_cnt=lambda_cnt,
            verbose=True
        )

        history["train_loss"].append(train_loss)
        history["eval_loss"].append(eval_loss)
        history["eval_acc"].append(eval_acc)
        history["eval_rmse"].append(eval_rmse)

        if eval_loss < best_eval_loss - 1e-4:
            best_eval_loss = eval_loss
            best_eval_loss_acc = eval_acc
            best_state = {k: v.cpu().clone() for k,v in net.state_dict().items()}
            bad_epochs = 0
            best_epoch = epoch
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(
                    f"Early stop at epoch {epoch}. "
                    f"Best val loss: {best_eval_loss:.4f}, "
                    f"with acc: {best_eval_loss_acc:.4f}, "
                    f"after epoch: {best_epoch}"
                )
                break

    if best_state is not None:
        net.load_state_dict(best_state)

    metrics = evaluate_metrics(net, device, test_loader)

    return net, history, metrics

In [14]:
def train_multiple_models(settings, save_dir=None):
    results = {}

    for mode, lambda_cnt in settings:
        if lambda_cnt is not None:
            print(f"\n=== Training mode={mode}, lambda_cnt={lambda_cnt} ===")
            net, history, metrics = train_model(mode=mode, lambda_cnt=lambda_cnt)

            results[(mode, lambda_cnt)] = {
                "history": history,
                "metrics": metrics,
            }

            lambda_str = str(lambda_cnt).replace(".", "_")
            model_path = os.path.join(save_dir, f"model_{mode}_lambda{lambda_str}.pt")
            torch.save(net.state_dict(), model_path)
            print(f"Saved model to {model_path}")
            
        else:
            print(f"\n=== Training mode={mode} ===")
            net, history, metrics = train_model(mode=mode)

            results[mode] = {
                "history": history,
                "metrics": metrics,
            }

            model_path = os.path.join(save_dir, f"model_{mode}.pt")
            torch.save(net.state_dict(), model_path)
            print(f"Saved model to {model_path}")

    results_path = os.path.join(save_dir, "results.pkl")
    with open(results_path, "wb") as f:
        pickle.dump(results, f)
    print(f"\nSaved all results dict to {results_path}")

    return results

In [15]:
settings = [
    ("cls_only", None),
    ("reg_only", None),
    ("multitask", 0.3),
    ("multitask", 0.5),
    ("multitask", 1.0),
]

train_multiple_models(settings, "models")


=== Training mode=cls_only ===
Eval Epoch: 1 | acc: 0.0050 | loss: 4.6753 | cls_loss: 4.6753 | cnt_loss: 1.4908
Eval Epoch: 2 | acc: 0.0130 | loss: 4.6208 | cls_loss: 4.6208 | cnt_loss: 1.5008
Eval Epoch: 3 | acc: 0.0620 | loss: 3.6990 | cls_loss: 3.6990 | cnt_loss: 1.5497
Eval Epoch: 4 | acc: 0.1420 | loss: 2.8057 | cls_loss: 2.8057 | cnt_loss: 1.6158
Eval Epoch: 5 | acc: 0.2240 | loss: 2.3039 | cls_loss: 2.3039 | cnt_loss: 1.6251
Eval Epoch: 6 | acc: 0.2360 | loss: 2.1355 | cls_loss: 2.1355 | cnt_loss: 1.6422
Eval Epoch: 7 | acc: 0.3110 | loss: 1.8019 | cls_loss: 1.8019 | cnt_loss: 1.6492
Eval Epoch: 8 | acc: 0.3480 | loss: 1.6831 | cls_loss: 1.6831 | cnt_loss: 1.6496
Eval Epoch: 9 | acc: 0.3790 | loss: 1.5597 | cls_loss: 1.5597 | cnt_loss: 1.6477
Eval Epoch: 10 | acc: 0.3540 | loss: 1.5707 | cls_loss: 1.5707 | cnt_loss: 1.6382
Eval Epoch: 11 | acc: 0.3690 | loss: 1.5218 | cls_loss: 1.5218 | cnt_loss: 1.6609
Eval Epoch: 12 | acc: 0.3870 | loss: 1.4730 | cls_loss: 1.4730 | cnt_loss: 

{'cls_only': {'history': {'train_loss': [4.727255440606012,
    4.67275235748291,
    4.222638869815403,
    3.3035047062767875,
    2.7288177575005426,
    2.3264633657667373,
    2.0876981253094145,
    1.9049107961654663,
    1.734792820294698,
    1.6766099809010824,
    1.6101154088974,
    1.5339518444273208,
    1.4821877431869508,
    1.4662117142147488,
    1.409798386891683,
    1.3554782410727606,
    1.3567470243242052,
    1.309702552901374,
    1.2928551042344836,
    1.2640149385664199,
    1.241083093855116,
    1.217800605032179,
    1.2073084926605224,
    1.181831755426195,
    1.142301187409295,
    1.1054777488178678,
    1.1021307236353557,
    1.0888401051627266,
    1.0770724494722155,
    1.0430111405054727,
    1.008509994453854,
    1.0139290359285142,
    1.0060970865885417,
    0.9810530780686273,
    0.9702429509692722,
    0.9447353371514214,
    0.9354657807879978,
    0.9037643768522474,
    0.8888993995984396,
    0.9066486066182454,
    0.873497624503