In [None]:
import requests
import torch
import os
# The notebook uses relative paths (e.g. the .pt files it loads/saves), and the
# local `taskdataset` import below presumably relies on this directory as well,
# so the working directory is switched to the modelstealing checkout first.
# Set ENSEMBLE_REPO_DIR to point at your own checkout; the fallback is the
# original author's machine-specific location.
os.chdir(
    os.environ.get(
        "ENSEMBLE_REPO_DIR",
        "/Users/mszszczepanowski/repos/ensemble-ai2024/modelstealing",
    )
)
from taskdataset import TaskDataset
from torchvision import transforms as T
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
import time
import pickle
import numpy as np

In [None]:
# Connection settings for the model-stealing server.
# Both values can be overridden through environment variables so the notebook
# does not have to be edited (or a token committed) to run against another
# team/server. The literals below are the original defaults.
# NOTE(review): a real team token is embedded here as the fallback — rotate it
# and drop the default once every user sets ENSEMBLE_TEAM_TOKEN.
TEAM_TOKEN = os.environ.get("ENSEMBLE_TEAM_TOKEN", "8J40ASDQOjfeeSKL")
SERVER_URL = os.environ.get("ENSEMBLE_SERVER_URL", "http://34.71.138.79:9090")

In [None]:
class NoiseDetector:
    """Collect endpoint representations for a list of images while tracking drift.

    Every image is sent to the ``/modelstealing`` endpoint and the returned
    representation is stored. Periodically the first image (the "base" image)
    is sent again and the MSE between its fresh representation and the one
    obtained at construction time is recorded, so any change in the endpoint's
    answer for an identical input shows up in ``mse_history``.

    Args:
        images: Indexable sequence of image tensors accepted by
            ``T.ToPILImage``; ``images[0]`` becomes the base image.
        reprs_path: File the stacked representations are checkpointed to.
        imgs_path: File the stacked images are checkpointed to.
        checkpoint_every: A checkpoint is written whenever the index of a
            successfully processed image is a multiple of this (positive) value.
        base_check_every: The base image is re-queried whenever the image
            index is a multiple of this (positive) value.
        request_timeout: Seconds to wait for the server before the request
            raises; ``None`` waits indefinitely.

    Note:
        Constructing the object already performs one network request (for the
        base image).
    """

    def __init__(
        self,
        images,
        reprs_path='stacked_reprs_new5.pt',
        imgs_path='stacked_imgs_new5.pt',
        checkpoint_every=100,
        base_check_every=5,
        request_timeout=60,
    ):
        # images -> list[images as a tensor]
        self.images = images
        self.base_img = self.images[0]

        # Settings must exist before the first endpoint call below.
        self.reprs_path = reprs_path
        self.imgs_path = imgs_path
        self.checkpoint_every = checkpoint_every
        self.base_check_every = base_check_every
        self.request_timeout = request_timeout

        # Reference answer every later base-image query is compared against.
        self.base_representation = self.invoke_endpoint(image=self.base_img)

        # Parallel lists: one entry per successfully processed image.
        self.mse_history = []
        self.images_representations = []
        self.ids = []

        self.images_to_save = []

    def save_tensor_as_png_file(self, img_tensor, path_to_png_file):
        """Convert ``img_tensor`` to a PIL image and write it to ``path_to_png_file``."""
        img_pil = T.ToPILImage()(img_tensor)
        img_pil.save(path_to_png_file)

    def calculate_mse(self, vector1, vector2):
        """Return the mean squared error between two equal-length vectors.

        Raises:
            ValueError: If the vectors differ in length (instead of silently
                comparing only the common prefix) or are empty.
        """
        if len(vector1) != len(vector2):
            raise ValueError(
                f"Vectors must have the same length, got {len(vector1)} and {len(vector2)}"
            )
        if len(vector1) == 0:
            raise ValueError("Cannot compute the MSE of empty vectors")
        sum_of_squares = sum((v1 - v2) ** 2 for v1, v2 in zip(vector1, vector2))
        return sum_of_squares / len(vector1)

    def invoke_endpoint(self, image, path_to_png_file="img.png"):
        """Send ``image`` to the endpoint and return its ``"representation"``.

        The tensor is first written to ``path_to_png_file`` (overwriting any
        existing file) and that file is uploaded.

        Raises:
            Exception: If the server answers with a status code other than 200.
            requests.exceptions.RequestException: On connection errors or when
                ``self.request_timeout`` elapses.
        """
        self.save_tensor_as_png_file(image, path_to_png_file=path_to_png_file)
        ENDPOINT = "/modelstealing"
        URL = SERVER_URL + ENDPOINT

        with open(path_to_png_file, "rb") as img_file:
            # Without a timeout a single stuck connection would block forever.
            response = requests.get(
                URL,
                files={"file": img_file},
                headers={"token": TEAM_TOKEN},
                timeout=self.request_timeout,
            )

        if response.status_code == 200:
            return json.loads(response.content.decode())["representation"]
        raise Exception(f"Request failed. Status code: {response.status_code}, content: {response.content}")

    def _save_checkpoint(self):
        """Stack everything collected so far and write both checkpoint files."""
        stacked_reprs = torch.stack(self.images_representations)
        stacked_imgs = torch.stack(self.images_to_save)

        torch.save(stacked_reprs, self.reprs_path)
        torch.save(stacked_imgs, self.imgs_path)
        print(f"stacked_images shape -> {stacked_imgs.shape}")
        print(f"stacked_reprs shape -> {stacked_reprs.shape}")

    def run(self):
        """Query the endpoint for every image and checkpoint the results.

        Images whose request fails are reported, followed by a 3 second pause,
        and skipped. Results are written every ``checkpoint_every`` indices and
        once more after the last image, so the tail of the run is persisted too.
        """
        base_representation_mse = 0
        # Give tqdm the length so it can render a real progress bar / ETA.
        total = len(self.images) if hasattr(self.images, "__len__") else None
        pbar = tqdm(enumerate(self.images), total=total)
        unsaved_results = False
        for i, img in pbar:
            pbar.set_description(f"MSE: {base_representation_mse}")
            try:
                img_representation = self.invoke_endpoint(image=img)

                # Get MSE between the first representation of the base image vs the new one
                if i % self.base_check_every == 0:
                    new_base_img_representation = self.invoke_endpoint(image=self.base_img)
                    base_representation_mse = self.calculate_mse(new_base_img_representation, self.base_representation)
            except Exception as e:
                print(e)
                time.sleep(3)
                continue
            self.images_representations.append(torch.from_numpy(np.array(img_representation)))
            self.mse_history.append(base_representation_mse)
            self.images_to_save.append(img)
            self.ids.append(i)
            unsaved_results = True
            if i % self.checkpoint_every == 0:
                self._save_checkpoint()
                unsaved_results = False

        # Persist whatever was collected after the last periodic checkpoint.
        if unsaved_results:
            self._save_checkpoint()

In [None]:
# Load the public dataset (alternative input file: "xd_images.pt") and turn every
# sample's image into an RGB tensor; only the samples from position 8000 onwards
# are kept for this run.
dataset = torch.load("ModelStealingPub.pt")
all_images = [
    T.ToTensor()(picture.convert("RGB"))
    for _first, picture, _third in dataset
][8000:]
len(all_images)

In [None]:
# Note: constructing the detector already sends the first image to the endpoint
# (one network request) to obtain the reference representation.
noise_detector = NoiseDetector(images=all_images)

In [None]:
# Queries the endpoint for every image; checkpoints are written to disk along the way.
noise_detector.run()

# Plotting MSE history

In [None]:
# NoiseDetector.run appends one MSE value per successfully processed image, so
# the x-axis is a position in that sequence, not a training epoch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(noise_detector.mse_history, marker='o', linestyle='-', color='b')
ax.set_title('History of Mean Square Error (MSE)')
ax.set_xlabel('Processed image #')
ax.set_ylabel('MSE')
ax.grid(True)
plt.show()