In [1]:
%load_ext autoreload
%autoreload 2

## Standard libraries
import os
import numpy as np
import random
from PIL import Image
from types import SimpleNamespace
from dotenv import load_dotenv

load_dotenv()

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline

import matplotlib
import seaborn as sns

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.transforms import v2
import torchvision.models as models

import lightning as L
from torch.utils.data import DataLoader

from example_submission import TaskDataset
from torch.utils.data import Dataset
from typing import Tuple

import wandb
from pytorch_lightning.loggers import WandbLogger

In [2]:
class ModifiedDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.ids = dataset.ids
        self.imgs = dataset.imgs
        self.labels = [int(l) for l in dataset.labels]

        self.transform = transform

        self.number_of_classes = len(set(self.labels))
        self.classes_mapping = {label: i for i, label in enumerate(set(self.labels))}

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        label = self.classes_mapping[self.labels[index]]
        return id_, img, label

    def __len__(self):
        return len(self.ids)

In [14]:
DATASET_PATH = os.getenv("TASK_2_DATA_PUBLIC_PATH")

transform = transforms.Compose([
    transforms.Lambda(lambda x: x.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.2980, 0.2962, 0.2987], std = [0.2886, 0.2875, 0.2889]),
])

all_dataset = torch.load(DATASET_PATH)
all_dataset = ModifiedDataset(all_dataset, transform)

In [15]:
len(set(all_dataset.labels))

50

In [16]:
train_size = int(0.9 * len(all_dataset))
valid_size = len(all_dataset) - train_size

train_dataset, valid_dataset = torch.utils.data.random_split(all_dataset, [train_size, valid_size])
train_dataset.transform = transform
valid_dataset.transform = transform

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [17]:
len(train_dataset), len(valid_dataset)

(11700, 1300)

In [18]:
class PretrainingStealingModule(L.LightningModule):
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_acc", acc, prog_bar=True, logger=True)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", acc, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

class Model(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        match model_name:
            case "resnet18":
                self.backbone = models.resnet18(pretrained=True)
            case "resnet50":
                self.backbone = models.resnet50(pretrained=True)
            case _:
                raise NotImplementedError

        for param in self.backbone.parameters():
            param.requires_grad = False

        self.representation = nn.Linear(self.backbone.fc.in_features, 1024)
        self.projection = nn.Linear(1024, num_classes)
            
        self.backbone.fc = nn.Identity()
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.representation(x)
        x = self.projection(x)
        return x

# Training on avalaible labeled examples

In [19]:
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="ensemble-ai",
    # Set the wandb project where this run will be logged.
    project="Ensemble AI Hackathon",
    # Track hyperparameters and run metadata.
    config={
     
    },
    group="task-2"
)

wandb_logger = WandbLogger(project="ensemble-ai", entity="Ensemble AI Hackathon")

model = Model(model_name="resnet18", num_classes=50)
lightning_model = PretrainingStealingModule(model)
trainer = L.Trainer(max_epochs=100, accelerator="gpu" if torch.cuda.is_available() else "cpu", logger=wandb_logger)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\utilities\parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
from torchinfo import summary

print(summary(model, input_size=(1, 3, 32, 32)))

Layer (type:depth-idx)                        Output Shape              Param #
Model                                         [1, 50]                   --
├─ResNet: 1-1                                 [1, 512]                  --
│    └─Conv2d: 2-1                            [1, 64, 16, 16]           (9,408)
│    └─BatchNorm2d: 2-2                       [1, 64, 16, 16]           (128)
│    └─ReLU: 2-3                              [1, 64, 16, 16]           --
│    └─MaxPool2d: 2-4                         [1, 64, 8, 8]             --
│    └─Sequential: 2-5                        [1, 64, 8, 8]             --
│    │    └─BasicBlock: 3-1                   [1, 64, 8, 8]             (73,984)
│    │    └─BasicBlock: 3-2                   [1, 64, 8, 8]             (73,984)
│    └─Sequential: 2-6                        [1, 128, 4, 4]            --
│    │    └─BasicBlock: 3-3                   [1, 128, 4, 4]            (230,144)
│    │    └─BasicBlock: 3-4                   [1, 128, 4, 4]        

In [20]:
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\pytorch_lightning\loggers\wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Model            | 11.8 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
576 K     Trainable params
11.2 M    Non-trainable params
11.8 M    Total params
47.012    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0:  99%|█████████▉| 91/92 [00:03<00:00, 23.17it/s, v_num=gtzn, train_acc=0.453, train_loss=1.650]

  return F.conv2d(input, weight, bias, self.stride,


Epoch 0: 100%|██████████| 92/92 [00:03<00:00, 23.18it/s, v_num=gtzn, train_acc=0.442, train_loss=1.630]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/11 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s]
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 64.53it/s]
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 39.03it/s]
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 36.48it/s]
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 35.63it/s]
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 35.10it/s]
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 35.02it/s]
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 35.15it/s]
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 34.32it/s]
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 34.00it/s]
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 33.72it/s]
Validation Dat

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 92/92 [00:07<00:00, 12.16it/s, v_num=gtzn, train_acc=0.577, train_loss=1.400, val_loss=1.910, val_acc=0.506]


# Training on quered representations

In [16]:
model.projection = nn.Identity()

print(summary(model, input_size=(1, 3, 32, 32)))

Layer (type:depth-idx)                        Output Shape              Param #
Model                                         [1, 1024]                 --
├─ResNet: 1-1                                 [1, 512]                  --
│    └─Conv2d: 2-1                            [1, 64, 16, 16]           9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 16, 16]           128
│    └─ReLU: 2-3                              [1, 64, 16, 16]           --
│    └─MaxPool2d: 2-4                         [1, 64, 8, 8]             --
│    └─Sequential: 2-5                        [1, 64, 8, 8]             --
│    │    └─BasicBlock: 3-1                   [1, 64, 8, 8]             73,984
│    │    └─BasicBlock: 3-2                   [1, 64, 8, 8]             73,984
│    └─Sequential: 2-6                        [1, 128, 4, 4]            --
│    │    └─BasicBlock: 3-3                   [1, 128, 4, 4]            230,144
│    │    └─BasicBlock: 3-4                   [1, 128, 4, 4]            295,42

In [18]:
class StolenRepresentationDataset(Dataset):
    def __init__(self, ids, imgs, representations, transform=None):
        self.ids = ids
        self.imgs = imgs
        self.representations = representations

        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        representation = self.representations[index]
        return id_, img, representation

    def __len__(self):
        return len(self.ids)

train_ratio = 0.9
val_ratio = 0.1

total_size = len(dataset)
print(total_size)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)

stolen_representation_dataset = StolenRepresentationDataset(
    ids=dataset.ids,
    imgs=dataset.imgs,
    representations=torch.randn(1, 1024).repeat(len(dataset), 1),
    transform=transform
)
train_dataset, val_dataset = torch.utils.data.random_split(stolen_representation_dataset, [train_size, val_size])

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

13000


In [19]:
class RepresentationStealingModule(L.LightningModule):
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = model
        self.criterion = nn.MSELoss()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("val_loss", loss, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

lightning_model = RepresentationStealingModule(model)
trainer = L.Trainer(max_epochs=100, accelerator="gpu" if torch.cuda.is_available() else "cpu", logger=wandb_logger)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\utilities\parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\pytorch_lightning\loggers\wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:653: Checkpoint directory .\ensemble-ai\z6tcoh87\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params
--------------------------------------
0 | model     | Model   | 11.7 M
1 | criterion | MSELoss | 0     
--------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.807    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\loops\fit_loop.py:298: The number of training batches (46) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 46/46 [00:05<00:00,  8.30it/s, v_num=oh87, train_loss=0.0148] 
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 71.32it/s]
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 27.53it/s]
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 22.57it/s]
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 21.05it/s]
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 20.14it/s]
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 22.70it/s]
Epoch 1: 100%|██████████| 46/46 [00:04<00:00,  9.25it/s, v_num=oh87, train_loss=0.00108, val_loss=0.0482]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s]
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 3

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


# Submission

In [None]:
# model.projection = nn.Identity()

path = 'task_2_submission_new.onnx'

torch.onnx.export(
    lightning_model.model,
    torch.randn(1, 3, 32, 32),
    path,
    export_params=True,
    input_names=["x"],
    output_names=["output"],
    dynamic_axes={
        "x": {0: "batch_size"},  # Make the batch dimension dynamic
        "output": {0: "batch_size"}
    }
)

In [None]:
import onnxruntime as ort

with open(path, "rb") as f:
    stolen_model = f.read()
    try:
        stolen_model = ort.InferenceSession(stolen_model)
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")
    try:
        out = stolen_model.run(
            None, {"x": np.random.randn(1, 3, 32, 32).astype(np.float32)}
        )[0][0]
    except Exception as e:
        raise Exception(f"Some issue with the input, {e=}")
    assert out.shape == (1024,), "Invalid output shape"

1024
