# In [None]:
import matplotlib.pyplot as plt
import numpy as np
import timm
import timm.data
import torch
import torchvision
from PIL import Image

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

class NetworkWrapper(torch.nn.Module):
    """Chain a preprocessing step and a network into one module.

    The wrapped ``network`` is put into evaluation mode when the wrapper is
    built; the wrapper module itself keeps its default training flag.

    Args:
        network: Module applied to the preprocessed input.
        preprocess_fn: Callable applied to the raw input before ``network``.
    """

    def __init__(self, network, preprocess_fn):
        super().__init__()
        self.preprocess_fn = preprocess_fn
        # eval() returns the same module, so this stores `network` itself.
        self.network = network.eval()

    def forward(self, x):
        """Return ``network(preprocess_fn(x))``."""
        return self.network(self.preprocess_fn(x))

class Visualization(torch.nn.Module):
    """A learnable RGB image optimised so a network's output matches a target.

    The image is stored as an unconstrained parameter of shape ``(1, 3, h, w)``
    and squashed into ``[0, 1]`` with a sigmoid. The optimiser can therefore
    update it freely while the rendered image always stays a valid RGB image.

    Args:
        h: Image height in pixels (also the height of augmented crops).
        w: Image width in pixels (also the width of augmented crops).
    """

    def __init__(self, h, w):
        super(Visualization, self).__init__()
        self.__data = torch.nn.Parameter(torch.randn(1, 3, h, w))
        self.out_h = h
        self.out_w = w

    def __augment(self, x, batch_size):
        """Return ``batch_size`` independent random resized crops of ``x``.

        A torchvision transform applied to a batched tensor samples ONE set of
        crop parameters for the whole batch. Each copy is therefore cropped
        separately, so the batch really contains different views.
        """
        crop = torchvision.transforms.RandomResizedCrop(
            [self.out_h, self.out_w], antialias=True
        )
        return torch.cat([crop(x) for _ in range(batch_size)], dim=0)

    def __reparameterize(self, x):
        """Map the unconstrained parameter into the ``[0, 1]`` pixel range."""
        return torch.sigmoid(x)

    def reset(self):
        """Re-draw the image parameter from N(0, 1), in place.

        The existing ``Parameter`` object is updated instead of being replaced.
        Its device is kept, and any optimiser that holds it stays valid.
        """
        with torch.no_grad():
            self.__data.normal_()

    def forward(self, batch_size, augment=True):
        """Render a batch of ``batch_size`` images of shape ``(batch_size, 3, h, w)``.

        Args:
            batch_size: Number of images in the returned batch.
            augment: If True (default), every image is an independent random
                resized crop. If False, the batch holds ``batch_size`` copies
                of the un-augmented image.
        """
        x = self.__reparameterize(self.__data)
        if augment:
            return self.__augment(x, batch_size)
        return x.repeat(batch_size, 1, 1, 1)

    def to_img(self):
        """Return the current, un-augmented image as an 8-bit RGB PIL image."""
        with torch.no_grad():
            x = self.__reparameterize(self.__data)
            # Drop the batch axis explicitly; squeeze() would also drop h or w == 1.
            x = x[0].cpu().numpy()
            x = np.transpose(x, (1, 2, 0))  # CHW -> HWC, the layout PIL expects
            x = np.clip(x, 0, 1)
            return Image.fromarray((x * 255).astype(np.uint8))

net = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)
# Normalise with the statistics from the checkpoint's own pretrained config
# instead of hard-coded ImageNet values. timm's default weights for this ViT
# list mean = std = (0.5, 0.5, 0.5); verify via `data_cfg` if the tag changes.
data_cfg = timm.data.resolve_data_config({}, model=net)
preprocess_fn = torchvision.transforms.Normalize(mean=data_cfg["mean"], std=data_cfg["std"])
model = NetworkWrapper(net, preprocess_fn).to(device)
# Only the image is optimised. Freezing the network stops autograd from
# computing and accumulating gradients for its parameters on every step.
model.requires_grad_(False)

vis = Visualization(224, 224).to(device)

# Random unit-norm target vectors in the network's output space. The
# "class_{i}" in the saved file names indexes these targets, not ImageNet classes.
targets = [torch.randn(model.network.num_classes) for _ in range(10)]
targets = [t / t.norm() for t in targets]
targets = torch.stack(targets).to(device)


for i, target in enumerate(targets):
    for aug in [True, False]:
        # Start every run from fresh noise with fresh AdamW state. Without
        # this, each run would continue from the previous run's image and moments.
        vis.reset()
        optimizer = torch.optim.AdamW(params=vis.parameters(), lr=0.2)
        batch_size = 8 if aug else 1
        target_batch = target.expand(batch_size, -1)
        for _ in range(100):
            optimizer.zero_grad()
            output = model(vis(batch_size, augment=aug))
            loss = torch.nn.functional.mse_loss(output, target_batch)
            loss.backward()
            optimizer.step()

        img = vis.to_img()
        img_name = f"class_{i}_{'aug' if aug else 'no_aug'}.png"
        img.save(img_name)
        plt.imshow(img)
        plt.title(img_name)
        plt.show()