In [2]:
%load_ext autoreload
%autoreload 2

## Standard libraries
import os
import numpy as np
import random
from PIL import Image
from types import SimpleNamespace
from dotenv import load_dotenv

load_dotenv()

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torchvision.models as models

import lightning as L
from torch.utils.data import DataLoader

from example_submission import TaskDataset
from torch.utils.data import Dataset
from typing import Tuple

import wandb
from pytorch_lightning.loggers import WandbLogger

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  set_matplotlib_formats('svg', 'pdf') # For export


In [3]:
class ModifedDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.ids = dataset.ids
        self.imgs = dataset.imgs
        self.labels = [int(l) for l in dataset.labels]

        self.transform = transform

        self.number_of_classes = len(set(self.labels))
        self.classes_mapping = {label: i for i, label in enumerate(set(self.labels))}

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        label = self.classes_mapping[self.labels[index]]
        return id_, img, label

    def __len__(self):
        return len(self.ids)

In [None]:
DATASET_PATH = os.getenv("TASK_2_DATA_PUBLIC_PATH")

transform = transforms.Compose([
    # transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
    transforms.Lambda(lambda x: x.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.2980, 0.2962, 0.2987], std = [0.2886, 0.2875, 0.2889]),
])

dataset = torch.load(DATASET_PATH)
dataset = ModifedDataset(dataset, transform)

In [5]:
len(set(dataset.labels))

50

In [6]:
# Define the split ratios
train_ratio = 0.9
val_ratio = 0.1
# test_ratio = 0.1

# Calculate the lengths for each split
total_size = len(dataset)
print(total_size)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
# test_size = total_size - train_size - val_size

# Split the dataset
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create DataLoader objects for each split
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

13000


In [7]:
len(train_dataset), len(val_dataset)

(11700, 1300)

In [8]:
for batch in train_loader:
    print(batch)
    break

[tensor([ 81069,  93230,  91806,  11163, 237129,  40383, 172721, 138638,  67922,
         49876,  99193,  21352,  37180, 232572, 131143, 233674,  44357, 165986,
         65188, 151928, 136982, 266259, 221028, 147033, 223548, 160966,  48469,
        185546, 244641,  81241, 203996, 252041, 172261, 245538, 281866, 163474,
          3501,   2295, 261550, 153586, 249527, 272478, 198213, 175777,  30250,
        231571, 218944, 275685, 274077, 285374, 118181,  78431,  33799, 100316,
        221527, 252750,   9839, 145117, 175818, 239575, 299324, 118961,   2387,
        157652,  34223, 178661,  49503, 293469, 125614, 136103, 107646, 160797,
         50560,  20122, 151793,  51965, 114630,  63734, 117012,  32670, 189976,
        214586,  72128, 305752, 167443, 160238, 180360, 208440, 105839, 192273,
        242533,  61617, 287411, 231169,  95383,  34146, 185633, 300503, 155695,
        136631, 149361, 237003, 203607, 213870,  65783, 179807,  98263,  69158,
         73641, 207996, 118734, 216817,

In [9]:
class StealingModule(L.LightningModule):
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_acc", acc, prog_bar=True, logger=True)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", acc, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

class Model(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        match model_name:
            case "resnet18":
                self.backbone = models.resnet18(pretrained=True)
            case _:
                raise NotImplementedError

        self.representation = nn.Linear(self.backbone.fc.in_features, 1024)
        self.projection = nn.Linear(1024, num_classes)
            
        self.backbone.fc = nn.Identity()
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.representation(x)
        x = self.projection(x)
        return x

In [None]:
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="ensemble-ai",
    # Set the wandb project where this run will be logged.
    project="Ensemble AI Hackathon",
    # Track hyperparameters and run metadata.
    config={
     
    },
    group="task-2"
)

wandb_logger = WandbLogger(project="ensemble-ai", entity="Ensemble AI Hackathon")

model = Model(model_name="resnet18", num_classes=50)
lightning_model = StealingModule(model)
trainer = L.Trainer(max_epochs=100, accelerator="gpu" if torch.cuda.is_available() else "cpu", logger=wandb_logger)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: arekpaterak to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\utilities\parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
from torchinfo import summary

print(summary(model, input_size=(1, 3, 32, 32)))

Layer (type:depth-idx)                        Output Shape              Param #
Model                                         [1, 50]                   --
├─ResNet: 1-1                                 [1, 512]                  --
│    └─Conv2d: 2-1                            [1, 64, 16, 16]           9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 16, 16]           128
│    └─ReLU: 2-3                              [1, 64, 16, 16]           --
│    └─MaxPool2d: 2-4                         [1, 64, 8, 8]             --
│    └─Sequential: 2-5                        [1, 64, 8, 8]             --
│    │    └─BasicBlock: 3-1                   [1, 64, 8, 8]             73,984
│    │    └─BasicBlock: 3-2                   [1, 64, 8, 8]             73,984
│    └─Sequential: 2-6                        [1, 128, 4, 4]            --
│    │    └─BasicBlock: 3-3                   [1, 128, 4, 4]            230,144
│    │    └─BasicBlock: 3-4                   [1, 128, 4, 4]            295,42

In [12]:
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\pytorch_lightning\loggers\wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Model            | 11.8 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.012    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


                                                                           

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\loops\fit_loop.py:298: The number of training batches (46) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   0%|          | 0/46 [00:00<?, ?it/s] 

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 0: 100%|██████████| 46/46 [00:03<00:00, 12.41it/s, v_num=oh87, train_acc=0.550, train_loss=1.330]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 56.80it/s]
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 33.98it/s]
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 30.07it/s]
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 29.71it/s]
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 29.01it/s]
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 29.02it/s]
Epoch 1: 100%|██████████| 46/46 [00:03<00:00, 14.45it/s, v_num=oh87, train_acc=0.600, train_loss=1.160, val_loss=1.390, val_acc=0.526]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLo

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 46/46 [00:05<00:00,  9.13it/s, v_num=oh87, train_acc=0.994, train_loss=0.0116, val_loss=2.770, val_acc=0.625]


In [61]:
model.projection = nn.Identity()

path = 'task_2_submission.onnx'

torch.onnx.export(
    lightning_model.model,
    torch.randn(1, 3, 32, 32),
    path,
    export_params=True,
    input_names=["x"],
    output_names=["output"],
    dynamic_axes={
        "x": {0: "batch_size"},  # Make the batch dimension dynamic
        "output": {0: "batch_size"}
    }
)

In [62]:
import onnxruntime as ort

with open(path, "rb") as f:
    stolen_model = f.read()
    try:
        stolen_model = ort.InferenceSession(stolen_model)
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")
    try:
        out = stolen_model.run(
            None, {"x": np.random.randn(1, 3, 32, 32).astype(np.float32)}
        )[0][0]
    except Exception as e:
        raise Exception(f"Some issue with the input, {e=}")
    assert out.shape == (1024,), "Invalid output shape"

print(len(out))

1024
