In [2]:
# !wget https://github.com/marcin119a/data/raw/refs/heads/main/data_gsn.zip
# !unzip data_gsn.zip &> /dev/null
# !rm data_gsn.zip

In [3]:
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import pandas as pd
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from torch import Tensor
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim

# Seed PyTorch's global RNG so weight initialisation and the torch-based
# random draws used elsewhere in this notebook start from a fixed state.
# NOTE(review): only torch is seeded here; numpy / random are not.
torch.manual_seed(1)


<torch._C.Generator at 0x7f837177e6f0>

In [4]:
class GSN(Dataset):
    """Image dataset read from ``<root>/data`` with per-image shape counts.

    ``labels.csv`` in the data directory supplies a ``name`` column (image
    file names) and six count columns. Each item is returned as
    ``(img, cls, cnt)`` where ``img`` is the grayscale image as a tensor,
    ``cnt`` is the 6-element float count vector and ``cls`` is the integer
    class id derived from ``cnt`` by :meth:`counts_to_class_id`.

    Args:
        root: Directory that contains the ``data`` folder.
        transform: Optional callable applied to the image tensor only.
        transform_relabel: Optional callable ``(img, cnt) -> (img, cnt)`` for
            augmentations that also have to change the counts.
    """

    # Count columns, in the order they appear in every count vector.
    _COLUMNS = ["squares", "circles", "up", "right", "down", "left"]
    # Every unordered pair of count indices, in lexicographic order; the
    # position of a pair in this list is the "pair index" of the class id.
    _PAIRS = [(i, j) for i in range(6) for j in range(i + 1, 6)]
    # Constant-time lookup instead of rebuilding and scanning the list per item.
    _PAIR_INDEX = {pair: idx for idx, pair in enumerate(_PAIRS)}

    def __init__(self, root, transform=None, transform_relabel=None):
        self.data_dir = os.path.join(root, "data")
        self.transform = transform
        self.transform_relabel = transform_relabel
        # Built once here rather than once per sample in __getitem__.
        self._to_tensor = transforms.ToTensor()

        df = pd.read_csv(os.path.join(self.data_dir, "labels.csv"))
        self.names = df["name"].tolist()
        self.labels = torch.tensor(df[self._COLUMNS].values, dtype=torch.float32)

    def __len__(self):
        """Number of images listed in ``labels.csv``."""
        return len(self.names)

    def __getitem__(self, index):
        """Load one sample and return ``(img, cls, cnt)``."""
        img_path = os.path.join(self.data_dir, self.names[index])

        # The context manager closes the underlying file deterministically;
        # convert() yields an independent image, so it is safe to use after.
        with Image.open(img_path) as raw:
            img = self._to_tensor(raw.convert("L"))

        if self.transform:
            img = self.transform(img)

        cnt = self.labels[index]

        if self.transform_relabel:
            img, cnt = self.transform_relabel(img, cnt)

        # The class id is computed after relabelling so it matches the
        # (possibly permuted) counts that are returned.
        cls = self.counts_to_class_id(cnt)

        return img, cls, cnt

    def counts_to_class_id(self, counts):
        """Map a count vector to a single class id.

        The id is ``pair_index * 9 + (count_of_lower_index - 1)``, where
        ``pair_index`` identifies which two entries are positive.

        Args:
            counts: Tensor or iterable of counts.

        Returns:
            Integer class id.

        Raises:
            ValueError: If the number of positive entries is not exactly two,
                if the counts do not sum to 10, if the count used for the id
                is not a whole number, or if the positive indices do not form
                a known pair.
        """
        if isinstance(counts, torch.Tensor):
            c = counts.detach().cpu().tolist()
        else:
            c = list(counts)

        nz = [i for i, v in enumerate(c) if v > 0]
        if len(nz) != 2:
            raise ValueError(f"Expected exactly 2 nonzero counts, got {len(nz)}: {c}")
        if sum(c) != 10:
            raise ValueError(f"Expected counts to sum to 10, got {sum(c)}: {c}")

        # enumerate() yields indices in ascending order, so nz is already sorted.
        a, b = nz
        ca = int(c[a])
        # int() truncates; a fractional count would silently map to the wrong
        # class, so reject it explicitly.
        if ca != c[a]:
            raise ValueError(f"Expected integer counts, got {c}")

        pair_index = self._PAIR_INDEX.get((a, b))
        if pair_index is None:
            raise ValueError(f"Unknown pair of nonzero indices {(a, b)}: {c}")

        return pair_index * 9 + (ca - 1)


In [5]:
class Augumentation:
    """Random quarter-turn rotation and flips that keep the counts consistent.

    The callable takes ``(img, cnt)`` and returns a transformed image together
    with a *copy* of ``cnt`` whose direction entries (indices 2..5, ordered
    up, right, down, left) are permuted to match the geometric transform.
    Indices 0 and 1 are never touched.

    Args:
        p_hflip: Probability of flipping along the last image dimension.
        p_vflip: Probability of flipping along image dimension 1.
    """

    def __init__(self, p_hflip=0.5, p_vflip=0.5):
        self.p_hflip, self.p_vflip = p_hflip, p_vflip

    @staticmethod
    def _swap(labels, i, j):
        """Exchange entries ``i`` and ``j`` of ``labels`` in place."""
        # Fancy indexing on the right-hand side copies, so this is a true swap.
        labels[[i, j]] = labels[[j, i]]

    def __call__(self, img, cnt):
        # Work on a copy so the caller's count tensor is left untouched.
        labels = cnt.clone()

        # 0..3 quarter turns. Rolling the direction block by the same amount
        # moves e.g. the "up" count into "right" for a single turn.
        quarter_turns = torch.randint(0, 4, (1,)).item()
        if quarter_turns != 0:
            img = torch.rot90(img, k=-quarter_turns, dims=[1, 2])
            labels[2:6] = torch.roll(labels[2:6], shifts=quarter_turns)

        # Flip along the last dimension: right (3) and left (5) trade places.
        if torch.rand(1).item() < self.p_hflip:
            img = torch.flip(img, dims=[2])
            self._swap(labels, 3, 5)

        # Flip along dimension 1: up (2) and down (4) trade places.
        if torch.rand(1).item() < self.p_vflip:
            img = torch.flip(img, dims=[1])
            self._swap(labels, 2, 4)

        return img, labels


In [6]:
class NeuralNetwork(nn.Module):
    """Shared convolutional backbone with a classification and a count head.

    ``forward`` returns ``(cls, cnt)``: log-probabilities over 135 classes and
    six unbounded count predictions. The first linear layer is sized for
    64 channels of 28x28 feature maps, i.e. single-channel 28x28 inputs.

    Args:
        cls_hidden: Width of the hidden layer in the classification head.
        dropout: Dropout probability used in the classification head.
    """

    def __init__(self, cls_hidden=256, dropout=0.5):
        super().__init__()

        feat_dim = 256
        channels = [1, 8, 16, 32, 64]

        # Four 3x3 convolutions with stride 1 / padding 1 keep the spatial
        # size, so only the channel count changes:
        # (B, 1, 28, 28) -> (B, 8, 28, 28) -> (B, 16, 28, 28)
        #                -> (B, 32, 28, 28) -> (B, 64, 28, 28)
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=1, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten(start_dim=1))                  # (B, 64 * 28 * 28)
        layers.append(nn.Linear(channels[-1] * 28 * 28, feat_dim))
        layers.append(nn.ReLU())
        self.backbone = nn.Sequential(*layers)

        # Classification head: one hidden layer, then log-probabilities so the
        # output can be fed straight into an NLL loss.
        self.head_cls = nn.Sequential(
            nn.Linear(feat_dim, cls_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(cls_hidden, 135),
            nn.LogSoftmax(dim=1),
        )

        # Count head: a single linear map from the shared features.
        self.head_cnt = nn.Sequential(nn.Linear(feat_dim, 6))

    def forward(self, x):
        features = self.backbone(x)
        # (B, 135) log-probabilities and (B, 6) count predictions.
        return self.head_cls(features), self.head_cnt(features)

In [7]:
def train_epoch(
    net: torch.nn.Module,
    device: torch.device,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    log_interval: int,
    lambda_cnt: float,
    verbose: bool = False,
) -> "tuple[float, float, float]":
    """Run one optimisation pass over ``train_loader``.

    The objective is ``nll_loss(cls) + lambda_cnt * smooth_l1_loss(cnt)``.

    Args:
        net: Model returning ``(cls_log_probs, cnt_pred)``.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(img, cls_target, cnt_target)``.
        optimizer: Optimiser stepped once per batch.
        epoch: Epoch number, used only in the progress message.
        log_interval: Print progress every this many batches when ``verbose``.
        lambda_cnt: Weight of the count-regression loss.
        verbose: Whether to print progress.

    Returns:
        ``(epoch_loss, epoch_cls_loss, epoch_cnt_loss)``: per-sample averages
        over the epoch.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    net.train()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_cnt_loss = 0.0
    n_total = 0

    for batch_idx, (img, cls_target, cnt_target) in enumerate(train_loader):
        img, cls_target, cnt_target = img.to(device), cls_target.long().to(device), cnt_target.to(device)

        optimizer.zero_grad()

        cls_pred, cnt_pred = net(img)

        cls_loss = F.nll_loss(cls_pred, cls_target)
        cnt_loss = F.smooth_l1_loss(cnt_pred, cnt_target)
        loss = cls_loss + lambda_cnt * cnt_loss

        loss.backward()
        optimizer.step()

        B = len(img)

        if verbose and batch_idx % log_interval == 0:
            # Images consumed before this batch. Using the running count
            # (instead of batch_idx * B) stays correct when the final batch
            # is smaller than the others.
            done = n_total
            total = len(train_loader.dataset)
            print(
                f"Train Epoch: {epoch} [{done}/{total} images ({done / total:.0%})]\t"
                + f"Loss: {loss.item():.6f}"
            )

        # Batch losses are means, so weight them by batch size to obtain a
        # true per-sample average at the end of the epoch.
        total_loss += loss.item() * B
        total_cls_loss += cls_loss.item() * B
        total_cnt_loss += cnt_loss.item() * B
        n_total += B

    epoch_loss = total_loss / n_total
    epoch_cls_loss = total_cls_loss / n_total
    epoch_cnt_loss = total_cnt_loss / n_total
    return epoch_loss, epoch_cls_loss, epoch_cnt_loss


In [8]:
def eval_epoch(
    net: torch.nn.Module,
    device: torch.device,
    test_loader: torch.utils.data.DataLoader,
    epoch: int,
    lambda_cnt: float,
    verbose: bool = False,
):
    """Evaluate ``net`` on ``test_loader`` without updating it.

    Args:
        net: Model returning ``(cls_log_probs, cnt_pred)``.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(img, cls_target, cnt_target)``.
        epoch: Epoch number, used only in the summary message.
        lambda_cnt: Weight of the count-regression loss in the total loss.
        verbose: Whether to print a one-line summary.

    Returns:
        ``(epoch_loss, epoch_cls_loss, epoch_cnt_loss, epoch_acc)``: per-sample
        average losses and the fraction of correctly classified samples.
    """
    net.eval()

    sum_loss, sum_cls, sum_cnt = 0.0, 0.0, 0.0
    seen = 0
    hits = 0

    with torch.no_grad():
        for batch in test_loader:
            img, cls_target, cnt_target = batch
            img = img.to(device)
            cls_target = cls_target.long().to(device)
            cnt_target = cnt_target.to(device)

            cls_pred, cnt_pred = net(img)
            cls_loss = F.nll_loss(cls_pred, cls_target)
            cnt_loss = F.smooth_l1_loss(cnt_pred, cnt_target)
            loss = cls_loss + lambda_cnt * cnt_loss

            # Losses are batch means; scale by batch size before summing so
            # the final division gives a per-sample average.
            batch_n = len(img)
            sum_loss += loss.item() * batch_n
            sum_cls += cls_loss.item() * batch_n
            sum_cnt += cnt_loss.item() * batch_n
            seen += batch_n

            # The most likely class is the arg-max of the log-probabilities.
            hits += (cls_pred.argmax(dim=1) == cls_target).sum().item()

    epoch_loss = sum_loss / seen
    epoch_cls_loss = sum_cls / seen
    epoch_cnt_loss = sum_cnt / seen
    epoch_acc = hits / seen

    if verbose:
        print(f"Eval Epoch: {epoch} accuracy: {epoch_acc:.4f} epoch_loss: {epoch_loss:.4f} epoch_cls_loss: {epoch_cls_loss:.4f} epoch_cnt_loss: {epoch_cnt_loss:.4f}\n")
    return epoch_loss, epoch_cls_loss, epoch_cnt_loss, epoch_acc

In [12]:
# --- Hyper-parameters -------------------------------------------------------
epochs = 100
lr = 1e-3
log_interval = 10
batch_size = 64
test_batch_size = 1000
lambda_cnt = 1.0

# Prefer CUDA, then Apple's MPS backend, then CPU. The availability check goes
# through torch.backends.mps, which is the documented entry point and exists
# on PyTorch builds that do not expose torch.mps.is_available().
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# --- Data loading -----------------------------------------------------------
num_workers = min(8, os.cpu_count() or 2)
pin = (device.type == "cuda")
train_kwargs = dict(
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=False,
)
if num_workers > 0:
    train_kwargs["prefetch_factor"] = 4

train_augumentation = Augumentation()

# Two views of the same files: only the training view applies augmentation.
train_dataset = GSN(root=".", transform_relabel=train_augumentation)
test_dataset = GSN(root=".")

# Fixed index split: the first 9000 samples train, the next 1000 are held out.
train_dataset = Subset(train_dataset, range(0, 9000))
test_dataset = Subset(test_dataset, range(9000, 10000))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **train_kwargs)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **train_kwargs)

# --- Model and optimiser ----------------------------------------------------
net = NeuralNetwork(cls_hidden=256, dropout=0.3).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

train_losses = []
eval_losses = []

# --- Early-stopping state ---------------------------------------------------
# NOTE(review): the held-out split is also used to pick the best checkpoint,
# so the accuracy reported for it is an optimistic estimate.
best_eval_loss = float("inf")
best_accuracy = 0.0  # accuracy at the epoch with the lowest eval loss
best_state = None
bad_epochs = 0
bad_epochs_patience = 10
optimal_epochs = epochs

for epoch in range(1, epochs+1):
    train_loss, train_cls_loss, train_cnt_loss = train_epoch(
        net,
        device,
        train_loader,
        optimizer,
        epoch,
        log_interval,
        lambda_cnt,
        verbose=False,
    )

    eval_loss, eval_cls_loss, eval_cnt_loss, eval_acc = eval_epoch(
        net,
        device,
        test_loader,
        epoch,
        lambda_cnt,
        verbose=True
    )

    # Record the history before the early-stop check so that the epoch that
    # triggers the stop is not missing from the curves.
    train_losses.append(train_loss)
    eval_losses.append(eval_loss)

    if eval_loss < best_eval_loss:
        best_eval_loss = eval_loss
        best_accuracy = eval_acc
        # Snapshot on the CPU; clone() is needed because .cpu() is a no-op
        # (returns the same tensor) when the model already lives on the CPU.
        best_state = {k: v.cpu().clone() for k,v in net.state_dict().items()}
        bad_epochs = 0
        optimal_epochs = epoch
    else:
        bad_epochs += 1
        if bad_epochs >= bad_epochs_patience:
            print(f"Early stop at epoch {epoch}. Best eval loss: {best_eval_loss:.4f}, Best accuracy = {best_accuracy} Optimal epochs: {optimal_epochs}")
            break

# Restore the weights from the epoch with the lowest evaluation loss.
if best_state is not None:
    net.load_state_dict(best_state)


Eval Epoch: 1 accuracy: 0.0260 epoch_loss: 5.4615 epoch_cls_loss: 4.2501 epoch_cnt_loss: 1.2114

Eval Epoch: 2 accuracy: 0.1900 epoch_loss: 3.1856 epoch_cls_loss: 2.5819 epoch_cnt_loss: 0.6038

Eval Epoch: 3 accuracy: 0.2420 epoch_loss: 2.5431 epoch_cls_loss: 2.0906 epoch_cnt_loss: 0.4525

Eval Epoch: 4 accuracy: 0.2860 epoch_loss: 2.2471 epoch_cls_loss: 1.8727 epoch_cnt_loss: 0.3743

Eval Epoch: 5 accuracy: 0.3160 epoch_loss: 1.9976 epoch_cls_loss: 1.6695 epoch_cnt_loss: 0.3281

Eval Epoch: 6 accuracy: 0.3340 epoch_loss: 2.0079 epoch_cls_loss: 1.6863 epoch_cnt_loss: 0.3216

Eval Epoch: 7 accuracy: 0.3550 epoch_loss: 1.8379 epoch_cls_loss: 1.5485 epoch_cnt_loss: 0.2894

Eval Epoch: 8 accuracy: 0.3490 epoch_loss: 1.9242 epoch_cls_loss: 1.5970 epoch_cnt_loss: 0.3272

Eval Epoch: 9 accuracy: 0.3860 epoch_loss: 1.7287 epoch_cls_loss: 1.4614 epoch_cnt_loss: 0.2672

Eval Epoch: 10 accuracy: 0.3620 epoch_loss: 1.8894 epoch_cls_loss: 1.5862 epoch_cnt_loss: 0.3032

Eval Epoch: 11 accuracy: 0.40

In [None]:
# Persist only the parameters (state_dict) of the current `net`, not the
# module object itself; loading requires rebuilding the model first.
save_path = "0512-net.pt"
torch.save(net.state_dict(), save_path)