In [14]:
# !wget https://github.com/marcin119a/data/raw/refs/heads/main/data_gsn.zip
# !unzip data_gsn.zip &> /dev/null
# !rm data_gsn.zip

In [15]:
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import pandas as pd
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from torch import Tensor
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)


<torch._C.Generator at 0x7fa94094e490>

In [16]:
# Every unordered pair (lo, hi) with 0 <= lo < hi <= 5, in lexicographic order: 15 pairs.
PAIRS = [(lo, hi) for lo in range(5) for hi in range(lo + 1, 6)]
# Reverse lookup: pair -> position of that pair inside PAIRS.
PAIR_TO_IDX = dict(zip(PAIRS, range(len(PAIRS))))
# Each pair has 9 possible splits (1..9 items for the first index), giving 135 classes.
N_CONFIGS = 9 * len(PAIRS)

def class_id_to_pair_and_split(class_id: int):
    """Decode a class id into its index pair and the split of 10 items between them.

    Class ids are laid out as ``pair_idx * 9 + (ca - 1)``: every entry of
    ``PAIRS`` owns nine consecutive ids, one per way of splitting 10 into two
    positive counts ``(ca, 10 - ca)`` with ``ca`` in 1..9.

    Args:
        class_id: Integer in ``[0, len(PAIRS) * 9)``.

    Returns:
        ``((i, j), (ca, cb))`` where ``(i, j)`` is the matching entry of
        ``PAIRS`` and ``ca + cb == 10``.

    Raises:
        IndexError: If ``class_id`` is outside the valid range. Negative ids
            are rejected explicitly; without the check Python's negative
            indexing would silently decode them as a pair from the end of
            ``PAIRS``.
    """
    if not 0 <= class_id < len(PAIRS) * 9:
        raise IndexError(f"class_id out of range: {class_id}")

    pair_idx = class_id // 9
    split_idx = class_id % 9  # 0..8 -> counts 1..9
    ca = split_idx + 1
    cb = 10 - ca
    return PAIRS[pair_idx], (ca, cb)

def class_id_to_pair(class_id: int):
    """Return the ``PAIRS`` entry that a class id belongs to.

    Each pair owns nine consecutive class ids, so the pair position is
    ``class_id // 9``.

    Args:
        class_id: Integer in ``[0, len(PAIRS) * 9)``.

    Returns:
        The ``(i, j)`` tuple stored in ``PAIRS`` for this class id.

    Raises:
        IndexError: If ``class_id`` is outside the valid range. Negative ids
            are rejected explicitly because negative list indexing would
            otherwise return a pair from the end of ``PAIRS``.
    """
    if not 0 <= class_id < len(PAIRS) * 9:
        raise IndexError(f"class_id out of range: {class_id}")
    return PAIRS[class_id // 9]

In [17]:
class GSN(Dataset):
    """Image dataset read from ``<root>/data`` with per-image count labels.

    ``labels.csv`` inside that directory must provide a ``name`` column (image
    file name relative to the data directory) and the six count columns
    ``squares, circles, up, right, down, left``.

    Each item is a tuple ``(img, cls, cnt)``:
        img: the image converted to grayscale (mode ``"L"``) and passed
            through ``transforms.ToTensor``.
        cls: integer class id derived from ``cnt`` (see
            ``counts_to_class_id``).
        cnt: float32 tensor with the six counts in the column order above.

    Args:
        root: Directory that contains the ``data`` folder.
        transform: Optional callable applied to the image tensor only.
        transform_relabel: Optional callable ``(img, cnt) -> (img, cnt)`` for
            augmentations that must change the labels together with the image.
    """

    def __init__(self, root, transform=None, transform_relabel=None):
        self.data_dir = os.path.join(root, "data")
        self.transform = transform
        self.transform_relabel = transform_relabel
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = transforms.ToTensor()

        df = pd.read_csv(os.path.join(self.data_dir, "labels.csv"))
        self.names = df["name"].tolist()
        cols = ["squares", "circles", "up", "right", "down", "left"]
        self.labels = torch.tensor(df[cols].values, dtype=torch.float32)

    def __len__(self):
        """Number of labelled images listed in ``labels.csv``."""
        return len(self.names)

    def __getitem__(self, index):
        """Load one sample and return ``(img, cls, cnt)``."""
        name = self.names[index]
        img_path = os.path.join(self.data_dir, name)

        # Image.open is lazy; the context manager guarantees the underlying
        # file is released even if decoding/conversion raises.
        with Image.open(img_path) as raw:
            img = self._to_tensor(raw.convert("L"))

        if self.transform:
            img = self.transform(img)

        cnt = self.labels[index]

        if self.transform_relabel:
            img, cnt = self.transform_relabel(img, cnt)

        # The class id is computed after relabelling so it always matches the
        # counts that are actually returned.
        cls = self.counts_to_class_id(cnt)

        return img, cls, cnt

    def counts_to_class_id(self, counts):
        """Encode a six-element count vector as a class id.

        Exactly two entries must be positive and all entries must sum to 10.
        With ``a < b`` the two positive positions, the id is
        ``PAIR_TO_IDX[(a, b)] * 9 + (counts[a] - 1)``.

        Args:
            counts: Tensor or iterable of counts.

        Returns:
            The integer class id.

        Raises:
            ValueError: If a count is negative or not a whole number, if the
                number of positive counts is not 2, or if the counts do not
                sum to 10.
        """
        if isinstance(counts, torch.Tensor):
            c = counts.detach().cpu().tolist()
        else:
            c = list(counts)

        # Negative or fractional values cannot be represented by the encoding;
        # previously they were truncated or ignored and produced a wrong id.
        if any(v < 0 or not float(v).is_integer() for v in c):
            raise ValueError(f"Expected non-negative integer counts, got: {c}")

        nz = [i for i, v in enumerate(c) if v > 0]
        if len(nz) != 2:
            raise ValueError(f"Expected exactly 2 nonzero counts, got {len(nz)}: {c}")
        if sum(c) != 10:
            raise ValueError(f"Expected counts to sum to 10, got {sum(c)}: {c}")

        a, b = sorted(nz)
        ca = int(c[a])
        pair_index = PAIR_TO_IDX[(a, b)]

        class_id = pair_index * 9 + (ca - 1)
        return class_id



In [18]:
class Augmentation:
    """Random quarter-turn rotation and mirror flips with matching label updates.

    The count vector is expected in the order
    ``(squares, circles, up, right, down, left)``; positions 2..5 are the
    direction counts and are permuted so they stay consistent with the
    transformed image. The incoming ``cnt`` tensor is never modified.

    Args:
        p_hflip: Probability of mirroring along the width axis (dim 2).
        p_vflip: Probability of mirroring along the height axis (dim 1).
    """

    def __init__(self, p_hflip=0.5, p_vflip=0.5):
        self.p_hflip = p_hflip
        self.p_vflip = p_vflip

    def __call__(self, img, cnt):
        """Return the augmented ``(img, cnt)`` pair."""
        counts = cnt.clone()

        # 0..3 quarter turns; a negative k makes rot90 turn the other way, and
        # the four direction counts are shifted by the same number of steps.
        quarter_turns = torch.randint(0, 4, (1,)).item()
        if quarter_turns != 0:
            img = torch.rot90(img, -quarter_turns, [1, 2])
            counts[2:6] = counts[2:6].roll(quarter_turns)

        # Mirror along the width axis: "right" and "left" trade places.
        mirror_width = torch.rand(1).item() < self.p_hflip
        if mirror_width:
            img = img.flip(2)
            counts[3], counts[5] = counts[5].clone(), counts[3].clone()

        # Mirror along the height axis: "up" and "down" trade places.
        mirror_height = torch.rand(1).item() < self.p_vflip
        if mirror_height:
            img = img.flip(1)
            counts[2], counts[4] = counts[4].clone(), counts[2].clone()

        return img, counts


In [19]:
class NeuralNetwork(nn.Module):
    """Shared convolutional backbone with a classification and a count head.

    The backbone stacks four 3x3 convolutions (stride 1, padding 1, so the
    spatial size is unchanged), flattens the result and projects it to a
    256-dimensional feature vector. Two heads read that vector:

    * ``head_cls`` -> log-probabilities over ``num_classes`` classes
      (``LogSoftmax``, so pair it with ``F.nll_loss``).
    * ``head_cnt`` -> ``num_counts`` unbounded regression outputs.

    Layer order is identical to the original hard-coded version, so
    ``state_dict`` keys and default parameter shapes are unchanged.

    Args:
        cls_hidden: Width of the hidden layer in the classification head.
        dropout: Dropout probability in the classification head.
        num_classes: Number of classes (default 135 = 15 pairs x 9 splits).
        num_counts: Number of regressed counts (default 6).
        in_channels: Channels of the input image (default 1, grayscale).
        image_size: Input height/width; an int for square images or a
            ``(height, width)`` pair. Needed because the flattened feature
            size of the first linear layer depends on it.
    """

    def __init__(
        self,
        cls_hidden=256,
        dropout=0.5,
        num_classes=135,
        num_counts=6,
        in_channels=1,
        image_size=28,
    ):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # Shape comments assume the defaults and a batch of 64.
        self.backbone = nn.Sequential(                                       # (64, 1, 28, 28)
            nn.Conv2d(in_channels, 8, 3, stride=1, padding=1), nn.ReLU(),    # (64, 8, 28, 28)
            nn.Conv2d(8, 16, 3, stride=1, padding=1), nn.ReLU(),             # (64, 16, 28, 28)
            nn.Conv2d(16, 32, 3, stride=1, padding=1), nn.ReLU(),            # (64, 32, 28, 28)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),            # (64, 64, 28, 28)
            nn.Flatten(start_dim=1),                                         # (64, 64 * 28 * 28)
            nn.Linear(64 * height * width, 256), nn.ReLU()                   # (64, 256)
        )

        self.head_cls = nn.Sequential(
            nn.Linear(256, cls_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(cls_hidden, num_classes),
            nn.LogSoftmax(dim=1)
        )

        self.head_cnt = nn.Sequential(
            nn.Linear(256, num_counts)
        )

    def forward(self, x):
        """Return ``(cls, cnt)``: class log-probabilities and count predictions."""
        x = self.backbone(x)

        cls = self.head_cls(x)  # (batch, num_classes)
        cnt = self.head_cnt(x)  # (batch, num_counts)

        return cls, cnt

In [20]:
def train_epoch(
    net: torch.nn.Module,
    device: torch.device,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    log_interval: int,
    lambda_cnt: float,
    mode: str,
    verbose: bool = False,
):
    """Run one optimisation pass over ``train_loader``.

    ``net`` must return ``(cls_log_probs, cnt_pred)``. Both losses are always
    computed (so both are reported), but ``mode`` decides which one is
    optimised:

    * ``"cls_only"``: NLL classification loss only.
    * ``"reg_only"``: ``lambda_cnt`` x smooth-L1 count loss only.
    * anything else: ``cls_loss + lambda_cnt * cnt_loss``.

    Args:
        net: Model to train; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Yields ``(img, cls_target, cnt_target)`` batches.
        optimizer: Optimizer over ``net``'s parameters.
        epoch: Epoch number, used only in log lines.
        log_interval: Log every ``log_interval`` batches when ``verbose``.
        lambda_cnt: Weight of the count loss.
        mode: Loss selection, see above.
        verbose: Print progress lines if true.

    Returns:
        ``(epoch_loss, epoch_cls_loss, epoch_cnt_loss)``, each the
        sample-weighted mean over the epoch.
    """
    net.train()
    total_loss = total_cls_loss = total_cnt_loss = 0.0
    n_total = 0

    for batch_idx, (img, cls_target, cnt_target) in enumerate(train_loader):
        img, cls_target, cnt_target = img.to(device), cls_target.long().to(device), cnt_target.to(device)

        optimizer.zero_grad()

        cls_pred, cnt_pred = net(img)

        cls_loss = F.nll_loss(cls_pred, cls_target)
        cnt_loss = F.smooth_l1_loss(cnt_pred, cnt_target)

        if mode == "cls_only":
            loss = cls_loss
        elif mode == "reg_only":
            loss = lambda_cnt * cnt_loss
        else:
            loss = cls_loss + lambda_cnt * cnt_loss

        loss.backward()
        optimizer.step()

        # Batch losses are means, so weight them by batch size to get a
        # correct per-sample average even when the last batch is smaller.
        batch_n = len(img)

        total_loss += loss.item() * batch_n
        total_cls_loss += cls_loss.item() * batch_n
        total_cnt_loss += cnt_loss.item() * batch_n
        n_total += batch_n

        if verbose and batch_idx % log_interval == 0:
            # Images consumed before this batch. The running total is used
            # instead of ``batch_idx * batch_n`` because the latter
            # under-reports progress on a final, smaller batch.
            done = n_total - batch_n
            total = len(train_loader.dataset)
            print(
                f"Train Epoch: {epoch} [{done}/{total} images ({done / total:.0%})]\t"
                + f"Loss: {loss.item():.6f}"
            )

    epoch_loss = total_loss / n_total
    epoch_cls_loss = total_cls_loss / n_total
    epoch_cnt_loss = total_cnt_loss / n_total
    return epoch_loss, epoch_cls_loss, epoch_cnt_loss


In [21]:
def eval_epoch(
    net: torch.nn.Module,
    device: torch.device,
    test_loader: torch.utils.data.DataLoader,
    epoch: int,
    lambda_cnt: float,
    mode: str,
    verbose: bool = False,
):
    """Evaluate ``net`` on ``test_loader`` without gradient tracking.

    The reported ``loss`` follows ``mode`` ("cls_only", "reg_only", or the
    weighted sum for any other value); the classification and count losses
    are always reported separately as well. All losses are sample-weighted
    means; accuracy is the fraction of samples whose arg-max class matches
    the target.

    Returns:
        ``(epoch_loss, epoch_cls_loss, epoch_cnt_loss, epoch_acc)``.
    """
    net.eval()

    loss_sum = 0.0
    cls_loss_sum = 0.0
    cnt_loss_sum = 0.0
    seen = 0
    hits = 0

    with torch.no_grad():
        for batch in test_loader:
            img, cls_target, cnt_target = batch
            img = img.to(device)
            cls_target = cls_target.long().to(device)
            cnt_target = cnt_target.to(device)

            cls_pred, cnt_pred = net(img)
            cls_loss = F.nll_loss(cls_pred, cls_target)
            cnt_loss = F.smooth_l1_loss(cnt_pred, cnt_target)

            # Start from the classification loss and replace it unless the
            # mode asks for classification only.
            loss = cls_loss
            if mode == "reg_only":
                loss = lambda_cnt * cnt_loss
            elif mode != "cls_only":
                loss = cls_loss + lambda_cnt * cnt_loss

            # Per-batch means are re-weighted by batch size before averaging.
            batch_n = len(img)
            loss_sum += loss.item() * batch_n
            cls_loss_sum += cls_loss.item() * batch_n
            cnt_loss_sum += cnt_loss.item() * batch_n
            seen += batch_n

            predicted = cls_pred.argmax(dim=1)
            hits += (predicted == cls_target).sum().item()

    epoch_loss = loss_sum / seen
    epoch_cls_loss = cls_loss_sum / seen
    epoch_cnt_loss = cnt_loss_sum / seen
    epoch_acc = hits / seen

    if verbose:
        print(
            f"Eval Epoch: {epoch} | "
            f"acc: {epoch_acc:.4f} | "
            f"loss: {epoch_loss:.4f} | "
            f"cls_loss: {epoch_cls_loss:.4f} | "
            f"cnt_loss: {epoch_cnt_loss:.4f}"
        )
    return epoch_loss, epoch_cls_loss, epoch_cnt_loss, epoch_acc

In [22]:
def create_loaders(root=".", batch_size=64, test_batch_size=1000, device="cpu"):
    """Build the training and evaluation ``DataLoader`` objects.

    Two ``GSN`` datasets are created over the same files: one with the
    label-aware ``Augmentation`` for training and one without augmentation
    for evaluation. Samples 0..8999 form the training subset and samples
    9000..9999 the evaluation subset, so ``labels.csv`` must list at least
    10000 images.

    Args:
        root: Directory containing the ``data`` folder read by ``GSN``.
        batch_size: Batch size of the (shuffled) training loader.
        test_batch_size: Batch size of the (unshuffled) evaluation loader.
        device: Target device as a ``torch.device``, a device string such as
            ``"cpu"`` / ``"cuda"``, or ``None``. Only used to decide whether
            to pin host memory.

    Returns:
        ``(train_loader, test_loader)``.
    """
    num_workers = min(8, os.cpu_count() or 2)
    # Normalise first: the default is the string "cpu", which has no ``.type``
    # attribute and used to raise AttributeError here.
    pin = device is not None and torch.device(device).type == "cuda"
    loader_kwargs = dict(
        num_workers=num_workers,
        pin_memory=pin,
        persistent_workers=False,
    )
    # prefetch_factor is only passed when worker processes are used.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 4

    train_aug = Augmentation()
    train_full = GSN(root=root, transform_relabel=train_aug)
    test_full = GSN(root=root)

    train_dataset = Subset(train_full, range(0, 9000))
    test_dataset = Subset(test_full, range(9000, 10000))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


In [23]:
# Training configuration read by train_model().
epochs = 100            # maximum number of epochs
lr = 1e-3               # Adam learning rate
batch_size = 64         # training batch size
test_batch_size = 1000  # evaluation batch size

cls_hidden = 256        # hidden width of the classification head
dropout = 0.3           # dropout probability in the classification head
patience = 10           # epochs without eval-loss improvement before stopping
# lambda_cnt (weight of the count loss) is passed to train_model() per run.

def train_model(
    mode: str,
    lambda_cnt: float,
    log_interval: int = 10,
):
    """Train a ``NeuralNetwork`` with early stopping and return the best model.

    The device is chosen as CUDA, then MPS, then CPU. Training runs for at
    most ``epochs`` epochs; after each one the model is evaluated and the
    weights with the lowest evaluation loss are kept. Training stops early
    once ``patience`` consecutive epochs fail to improve that loss, and the
    best weights are loaded back into the model before returning.

    Args:
        mode: Loss selection forwarded to ``train_epoch`` / ``eval_epoch``
            (``"cls_only"``, ``"reg_only"``, otherwise the weighted sum).
        lambda_cnt: Weight of the count-regression loss.
        log_interval: Batch logging interval forwarded to ``train_epoch``
            (per-batch logging is disabled here via ``verbose=False``).

    Returns:
        ``(net, history)`` where ``history`` maps metric names to per-epoch
        lists of train/eval losses and eval accuracy.
    """
    # torch.backends.mps.is_available() is the documented availability check;
    # torch.mps.is_available() is believed to be missing from older PyTorch
    # releases (TODO confirm), where it raises AttributeError on any machine
    # without CUDA.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if device.type == "cuda":
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Use the configured batch sizes rather than repeating the literals, so
    # editing the configuration above actually takes effect.
    train_loader, test_loader = create_loaders(
        root=".",
        batch_size=batch_size,
        test_batch_size=test_batch_size,
        device=device,
    )

    net = NeuralNetwork(cls_hidden, dropout).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    history = {
        "train_loss": [],
        "train_cls_loss": [],
        "train_cnt_loss": [],
        "eval_loss": [],
        "eval_cls_loss": [],
        "eval_cnt_loss": [],
        "eval_acc": [],
    }

    best_eval_loss = float("inf")
    best_acc = 0.0
    best_state = None
    bad_epochs = 0
    best_epoch = 0

    for epoch in range(1, epochs+1):
        train_loss, train_cls_loss, train_cnt_loss = train_epoch(
            net,
            device,
            train_loader,
            optimizer,
            epoch,
            log_interval,
            lambda_cnt,
            mode,
            verbose=False,
        )

        eval_loss, eval_cls_loss, eval_cnt_loss, eval_acc = eval_epoch(
            net,
            device,
            test_loader,
            epoch,
            lambda_cnt,
            mode,
            verbose=True
        )

        history["train_loss"].append(train_loss)
        history["train_cls_loss"].append(train_cls_loss)
        history["train_cnt_loss"].append(train_cnt_loss)
        history["eval_loss"].append(eval_loss)
        history["eval_cls_loss"].append(eval_cls_loss)
        history["eval_cnt_loss"].append(eval_cnt_loss)
        history["eval_acc"].append(eval_acc)

        if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            best_acc = eval_acc
            # Snapshot on CPU so the copy is independent of further training.
            best_state = {k: v.cpu().clone() for k,v in net.state_dict().items()}
            bad_epochs = 0
            best_epoch = epoch
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(
                    f"Early stop at epoch {epoch}. "
                    f"Best val loss: {best_eval_loss:.4f}, "
                    f"best acc: {best_acc:.4f}, "
                    f"best epoch: {best_epoch}"
                )
                break

    # Return the best-scoring weights, not the ones from the last epoch.
    if best_state is not None:
        net.load_state_dict(best_state)

    return net, history

In [26]:
# Joint training: classification loss plus the count loss with weight 1.0.
net, history = train_model(mode="multitask", lambda_cnt=1.0)


Eval Epoch: 1 | acc: 0.0120 | loss: 5.6244 | cls_loss: 4.3598 | cnt_loss: 1.2646
Eval Epoch: 2 | acc: 0.1510 | loss: 3.7003 | cls_loss: 2.9016 | cnt_loss: 0.7987
Eval Epoch: 3 | acc: 0.2370 | loss: 2.6300 | cls_loss: 2.1747 | cnt_loss: 0.4553
Eval Epoch: 4 | acc: 0.2930 | loss: 2.2671 | cls_loss: 1.8795 | cnt_loss: 0.3876
Eval Epoch: 5 | acc: 0.3230 | loss: 2.1656 | cls_loss: 1.7928 | cnt_loss: 0.3728
Eval Epoch: 6 | acc: 0.3560 | loss: 1.9604 | cls_loss: 1.6431 | cnt_loss: 0.3173
Eval Epoch: 7 | acc: 0.3300 | loss: 2.0548 | cls_loss: 1.7153 | cnt_loss: 0.3395
Eval Epoch: 8 | acc: 0.3650 | loss: 1.8294 | cls_loss: 1.5480 | cnt_loss: 0.2815
Eval Epoch: 9 | acc: 0.3770 | loss: 1.8168 | cls_loss: 1.5408 | cnt_loss: 0.2760
Eval Epoch: 10 | acc: 0.4020 | loss: 1.6715 | cls_loss: 1.4181 | cnt_loss: 0.2534
Eval Epoch: 11 | acc: 0.4150 | loss: 1.7058 | cls_loss: 1.4429 | cnt_loss: 0.2630
Eval Epoch: 12 | acc: 0.4020 | loss: 1.6799 | cls_loss: 1.4325 | cnt_loss: 0.2473
Eval Epoch: 13 | acc: 0.4

In [27]:
# Persist only the model's state_dict (parameters and buffers), not the module
# object; loading requires building a NeuralNetwork and calling load_state_dict.
save_path = "0535-net.pt"
torch.save(net.state_dict(), save_path)