In [None]:
%load_ext autoreload
%autoreload 2

## Standard libraries
import os
import numpy as np
import random
from PIL import Image
from types import SimpleNamespace
from dotenv import load_dotenv

load_dotenv()

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torchvision.models as models

import lightning as L
from torch.utils.data import DataLoader

from example_submission import TaskDataset
from torch.utils.data import Dataset
from typing import Tuple

  set_matplotlib_formats('svg', 'pdf') # For export


In [None]:
class ModifedDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.ids = dataset.ids
        self.imgs = dataset.imgs
        self.labels = [int(l) for l in dataset.labels]

        self.transform = transform

        self.number_of_classes = len(set(self.labels))
        self.classes_mapping = {label: i for i, label in enumerate(set(self.labels))}

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        label = self.classes_mapping[self.labels[index]]
        return id_, img, label

    def __len__(self):
        return len(self.ids)

In [None]:
DATASET_PATH = os.getenv("TASK_2_DATA_PUBLIC_PATH")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
    transforms.Normalize(mean = [0.2980, 0.2962, 0.2987], std = [0.2886, 0.2875, 0.2889]),
])

dataset = torch.load(DATASET_PATH)
dataset = ModifedDataset(dataset, transform)

In [None]:
len(set(dataset.labels))

50

In [None]:
# Define the split ratios
train_ratio = 0.9
val_ratio = 0.1
# test_ratio = 0.1

# Calculate the lengths for each split
total_size = len(dataset)
print(total_size)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
# test_size = total_size - train_size - val_size

# Split the dataset
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create DataLoader objects for each split
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

13000


In [59]:
len(train_dataset), len(val_dataset)

(11700, 1300)

In [45]:
for batch in train_loader:
    print(batch)
    break

[tensor([125042,  37746,  48298,    629, 112550,  78279,  26582, 125400, 235704,
        127706,  18051, 229673, 172382,  38813, 208684,   7682,  35155, 234306,
        227129, 219950, 211559, 292212,  22899, 167549, 281195, 124263,  98414,
        237129, 215098, 180095, 287061, 133070, 173987, 130627,  43295,  54633,
        225780, 179117,   1154, 164865, 157891,   8584,  79262, 162857, 266935,
        186996,  14334,  27297, 271741, 178680, 292850, 144421, 146437, 288145,
        113014,  20130, 118079,  12462, 112512, 305691, 210137,  88286, 206788,
        121070, 262805, 140369, 121173,  48630,  20624,    752,  86121,  31053,
         82436, 150115, 208901, 204892,  92799, 123427, 155076, 141454, 183893,
        195107,  48702,  33514,  79529, 195861, 288989,  35421,  67165, 162220,
        288087, 269807, 254197, 183590, 214496,  20971, 254072,  18886, 127074,
        130023,  31756, 115778,  93831, 136154, 219133,   5140, 303393, 258928,
        169791, 226007, 213496, 120529,

In [46]:
class StealingModule(L.LightningModule):
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_acc", acc, prog_bar=True, logger=True)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", acc, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        id, x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

class Model(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        match model_name:
            case "resnet18":
                self.backbone = models.resnet18(pretrained=True)
            case _:
                raise NotImplementedError

        self.representation = nn.Linear(self.backbone.fc.in_features, 1024)
        self.projection = nn.Linear(1024, num_classes)
            
        self.backbone.fc = nn.Identity()
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.representation(x)
        x = self.projection(x)
        return x

In [47]:
model = Model(model_name="resnet18", num_classes=50)
lightning_model = StealingModule(model)
trainer = L.Trainer(max_epochs=25, accelerator="gpu" if torch.cuda.is_available() else "cpu", logger=False)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\utilities\parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [61]:
from torchinfo import summary

print(summary(model))

AttributeError: 'bytes' object has no attribute 'parameters'

In [48]:
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:653: Checkpoint directory d:\Code\ensemble-ai\task_2\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Model            | 11.8 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.012    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\core\module.py:507: You called `self.log('val_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 17.23it/s]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\core\module.py:507: You called `self.log('val_acc', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


                                                                           

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0:   2%|▏         | 1/41 [00:00<00:06,  6.58it/s, train_acc=0.0586, train_loss=3.980]

d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\core\module.py:507: You called `self.log('train_acc', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
d:\ProgramFiles\Miniconda3\envs\thesis\Lib\site-packages\lightning\pytorch\core\module.py:507: You called `self.log('train_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Epoch 24: 100%|██████████| 41/41 [00:05<00:00,  7.29it/s, train_acc=0.963, train_loss=0.107, val_loss=2.430, val_acc=0.578] 

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 41/41 [00:06<00:00,  6.54it/s, train_acc=0.963, train_loss=0.107, val_loss=2.430, val_acc=0.578]


In [41]:
model.projection = nn.Identity()

path = 'task_2_submission.onnx'

torch.onnx.export(
    lightning_model.model,
    torch.randn(1, 3, 32, 32),
    path,
    export_params=True,
    input_names=["x"],
)

In [51]:
import onnxruntime as ort

with open(path, "rb") as f:
    model = f.read()
    try:
        stolen_model = ort.InferenceSession(model)
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")
    try:
        out = stolen_model.run(
            None, {"x": np.random.randn(1, 3, 32, 32).astype(np.float32)}
        )[0][0]
    except Exception as e:
        raise Exception(f"Some issue with the input, {e=}")
    assert out.shape == (1024,), "Invalid output shape"

print(len(out))

1024
