In [3]:
import torch
from augmentation import Clip, Tile, Jitter, RepeatBatch, ColorJitter
from augmentation.pre import GaussianNoise
from hooks.transformer.vit import ViTAttHookHolder, ViTGeLUHook
from inversion import ImageNetVisualizer
from inversion.utils import new_init
from loss import LossArray, TotalVariation
from loss.image_net import ViTFeatHook, ViTEnsFeatHook
from model import model_library
from saver import ExperimentSaver
from utils import exp_starter_pack

def run_main(layer, feature, grid, network, tv):
    """Optimize an input image to visualize one feature of a ViT layer, then save it.

    Builds the network from ``model_library[network]``, attaches a
    ``ViTEnsFeatHook`` objective on a ``ViTGeLUHook`` restricted to transformer
    block ``layer``, adds a total-variation regularizer, and runs an
    ``ImageNetVisualizer`` optimization. The result is saved through an
    ``ExperimentSaver`` whose name encodes layer, feature, network and tv.

    Args:
        layer: index of the block to hook; selected with ``slice(layer, layer + 1)``.
        feature: feature index passed as ``feat`` to ``ViTEnsFeatHook``.
        grid: accepted for call compatibility but currently unused (it does
            not affect tiling and is not part of the experiment name).
        network: key into ``model_library``.
        tv: multiplier on the base total-variation coefficient ``0.0005``.

    Returns:
        The optimized image tensor (also saved under the name ``'final'``).
    """
    model, image_size, _, _ = model_library[network]()
    saver = ExperimentSaver(f'VisL{layer}_F{feature}_N{network}_TV{tv}', save_id=True, disk_saver=True)

    loss = LossArray()
    # Hook only block `layer`; presumably the MLP GeLU activations given the
    # hook's name — TODO confirm against hooks.transformer.vit.
    loss += ViTEnsFeatHook(ViTGeLUHook(model, sl=slice(layer, layer + 1)), key='high', feat=feature, coefficient=1)
    loss += TotalVariation(2, image_size, coefficient=0.0005 * tv)

    # Per-step augmentations applied before the forward pass. The tile factor is
    # fixed at 1 (presumably no tiling); `grid` does not control it.
    pre = torch.nn.Sequential(
        RepeatBatch(8),
        ColorJitter(8, shuffle_every=True),
        GaussianNoise(8, True, 0.5, 400),
        Tile(1),
        Jitter(),
    )
    post = Clip()

    image = new_init(image_size, 1)
    visualizer = ImageNetVisualizer(loss, None, pre, post, print_every=10, lr=0.1, steps=400, save_every=100)
    image.data = visualizer(image)
    saver.save(image, 'final')
    return image

In [29]:
# Valid parameter ranges (author's note):
#   layer   in [0, 11]
#   feature in [0, 3071] for layer 0
# The result is bound to a name so the cell does not echo a return value.
final_image = run_main(layer=0, feature=3071, grid=0, network=35, tv=0.1)
0

Loaded pretrained weights.
#i	Loss	VTE	TV
0	-0.01	-0.05(-0.05)	0.04(777)
10	-0.47	-0.5(-0.5)	0.03(570)
20	-0.22	-0.25(-0.25)	0.03(596)
30	-0.31	-0.33(-0.33)	0.02(491)
40	-0.3	-0.32(-0.32)	0.03(529)
50	-0.38	-0.4(-0.4)	0.02(439)
60	-0.6	-0.62(-0.62)	0.02(405)
70	-0.27	-0.29(-0.29)	0.02(456)
80	-0.29	-0.31(-0.31)	0.02(451)
90	-0.56	-0.58(-0.58)	0.02(365)
100	-0.25	-0.27(-0.27)	0.02(438)
110	-0.64	-0.66(-0.66)	0.02(429)
120	-0.9	-0.92(-0.92)	0.03(502)
130	-0.14	-0.16(-0.16)	0.02(385)
140	-0.69	-0.71(-0.71)	0.02(354)
150	-0.21	-0.23(-0.23)	0.02(436)
160	-0.45	-0.47(-0.47)	0.02(439)
170	-0.72	-0.74(-0.74)	0.02(411)
180	-0.41	-0.43(-0.43)	0.02(386)
190	-0.59	-0.61(-0.61)	0.02(358)
200	-0.52	-0.53(-0.53)	0.02(337)
210	-0.39	-0.4(-0.4)	0.02(358)
220	-0.62	-0.64(-0.64)	0.02(489)
230	-0.56	-0.58(-0.58)	0.02(388)
240	-1.05	-1.06(-1.06)	0.02(311)
250	-0.53	-0.54(-0.54)	0.01(299)
260	-0.63	-0.66(-0.66)	0.02(437)
270	-0.47	-0.49(-0.49)	0.02(431)
280	-0.34	-0.36(-0.36)	0.02(390)
290	-0.46	-0.48(-0.48

0