In [None]:
import os
import sys

# Project root that contains the local `modules` package. The default is the
# original author's checkout; set the ENV_ATTACK_ROOT environment variable to
# run the notebook from another location without editing this cell.
path = os.environ.get('ENV_ATTACK_ROOT', '/home/zhuxiaopei_srt/data3/env_attack/attack')
if path not in sys.path:
    sys.path.append(path)
    print(sys.path)

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from PIL import Image
from modules.img_utils import tensor2img

In [None]:
## Hardware: CUDA device index (used by torch.device / torch.cuda.set_device below)
device_id = 0

## Logging
log_root_dir = '../logs/'
log_name = 'test'
log_comment = ''

## Background image (a later cell joins it as background_dir + background)
background_dir = '../background/'
background = 'forest0.jpg'

## Number of dominant colours extracted from the background image
main_color_cnt = 12
# Temperature passed to GumbleSoftmaxTexture (only used when texture_type == 'softmax')
tau = 0.4

## Texture generator: 'direct' -> DirectTexture, 'softmax' -> GumbleSoftmaxTexture
# texture_type = 'direct'
texture_type = 'softmax'
# Texture resolution handed to the texture generator
texture_org_shape = [512, 512]
# Passed to GumbleSoftmaxTexture; presumably freezes the Gumbel noise sample — TODO confirm
fixed_softmax_g = True

## Texture post-processing mask: 'image' -> ImageMask(**mask_image), 'screen' -> ScreenMask(**mask_screen)
# mask_type = 'image'
mask_type = 'screen'

mask_image = {
    'mask_file': '../models/rangerover/mask_eext.png',
    # 'mask_file': '../models/rangerover/mask_screen2.png',
    'random_brightness': (0.8, 1.1), 
    'random_bias_r': 0.02,
}

mask_screen = {
    # Each screen: four ints, a 3-vector and a '.pth' file (or None).
    # NOTE(review): presumably a texture-space rectangle, a facing direction
    # and a cached projection — semantics live in ScreenMask; TODO confirm.
    'screens': [
        # front
        (1090,   45, 1368,  539, [-0.1,  0.0,  1.0], 
        #  None),
         './screens/front4.pth'),

        # left
        (1320, 1000, 1565, 1525, [ 0.0, -1.0,  0.0], 
        #  None),
         './screens/left4.pth'),

        # top
        (1625,  390, 1935,  941, [ 0.0,  0.0,  1.0], 
        #  None),
         './screens/top4.pth'),
    ],
    'random_brightness': (0.8, 1.1), 
    'random_bias_r': 0.02,
}

## Car model (OBJ mesh)
obj_file = '../models/rangerover/rangerover.obj'

## Renderer parameters, passed as keyword arguments to Renderer.
# NOTE: the 'rander_*' spelling must match Renderer's keyword names; do not
# rename the keys here alone.
renderer_config = {
    'img_size': 640,
    'rander_scale': 0.345,
    'rander_car_at': (0, -.235, 0.07),
    'fov': 60,
}

## Detector under attack
# Detector input size; presumably should agree with renderer_config['img_size'] — TODO confirm
img_size = 640
detector_type = 'yolov9'
# detector_type = 'mmdet'

detector_yolov9 = {
    'weight': '../checkpoints/yolov9-c-converted.pt',
}
# MMDetection configuration (currently disabled)
# detector_mmdet = {
#     'checkpoint_dir': '../checkpoints/',
#     'model_name': 'faster_rcnn_r50_fpn_1x_coco',
#     'config_dir': '../mmdet/configs/faster_rcnn/',
# }

## Data generator: sampling ranges passed to train_data_sample(**data_config)
data_config = {
    # Distance between camera and car
    'Distance': (8, 12),
    # Elevation (pitch) angle, 0~90 deg.
    # The dict used to list 'Sigma' twice ((15, 45) and then (0, 50)). In a
    # dict literal the later key silently wins, so only (0, 50) was ever
    # used. The overridden range is kept as a comment for reference.
    # 'Sigma': (15, 45),
    'Sigma': (0, 50),
    # Azimuth angle (deg)
    #  +135    Y +90    +45
    #       /----R---+    
    # +180 (F [car] B|  +0 -> X
    #       \----L---+
    #  -135     -90     -45
    # 'Theta': (-180, +180),
    'Theta': (-190,  -30),
    # Roll angle (presumably degrees)
    'Roll': (-5, +5),
}

## Attack targets: detector class ids passed to net.cpt_loss and AttackTester
# (presumably COCO car / bus / truck — TODO confirm against the detector's labels)
target_class = [2, 5, 7]

## Evaluation parameters, passed as keyword arguments to AttackTester
# NOTE(review): presumably the attack success rate is reported once per
# confidence threshold — TODO confirm in AttackTester.
tester_config = {
    'bound_size': 0.2,
    'confidence_threshold': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
    'sample_cnt': 20,
}

## Training parameters
# Quick-debug settings:
# epochs = 1
# train_size = 200
# val_size = 30
epochs = 20
train_size = 1000
val_size = 1000

batch_size = 1
lr = 0.01

# Weight of the texture smoothness loss (loss_smooth) in the total loss
n_smooth = 10.0

# Epochs after which the texture is evaluated on the validation set (0 = before training)
# test_at = [0,1,2,4,10,20]
test_at = [epochs]

# Stream texture snapshots to TexServer during training and save a video at the end
realtime = True

In [None]:
# Batch mode: when launched as `<script> batch "<stmt>;<stmt>;..."`, every
# ';'-separated statement is executed to override the parameters above
# (e.g. "lr = 0.005;epochs = 10"). The statement list is logged as the comment.
# The length check avoids an IndexError when no arguments are given; a
# malformed `batch` invocation without a config string still fails loudly.
# SECURITY: exec() runs arbitrary code from the command line — trusted input only.
if len(sys.argv) > 1 and sys.argv[1] == 'batch':
    arg_config = sys.argv[2].split(';')
    for s in arg_config:
        exec(s)
    log_comment = arg_config

# Prefix the background directory only once, so re-running this cell does not
# keep prepending it.
if not background.startswith(background_dir):
    background = background_dir + background

In [None]:
# Display PIL images inline in the notebook when `Image.show()` is called.
from PIL import ImageShow


def show_in_notebook(image, **options):
    """Render a PIL image inline through IPython's rich display.

    Args:
        image: the PIL image to show.
        **options: extra options forwarded by ``Image.show`` (ignored).

    Returns:
        True, telling ImageShow the image has been handled.
    """
    display(image)
    return True


class NotebookViewer(ImageShow.Viewer):
    """ImageShow viewer that delegates to ``show_in_notebook``.

    ``ImageShow.show`` calls ``viewer.show(image, ...)`` on each registered
    viewer, so a bare function cannot be registered directly; it needs a
    viewer object exposing ``show``.
    """

    def show(self, image, **options):
        return show_in_notebook(image, **options)


# order=-1 inserts this viewer ahead of any desktop viewers Pillow registered
# on import, so `.show()` in later cells renders inline.
ImageShow.register(NotebookViewer, order=-1)

In [None]:
# Device setup: print the torch version, require CUDA and make GPU `device_id`
# the current device.
print(torch.__version__)
# print(os.environ['PATH'])

# Explicit check instead of `assert`, which is skipped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this notebook requires a GPU.')

device = torch.device(device_id)
print(device)
torch.cuda.set_device(device)

# if cpu_limit > 0:
#     os.sched_setaffinity(0, list(range(cpu_limit)))

In [None]:
# Logger for this run; the run configuration is recorded via logger.pprint.
from modules.logger import Logger
logger = Logger(log_root_dir, log_name)
print(logger.log_dir)

logger.pprint(
    comment=log_comment,

    background=background,

    main_color_cnt=main_color_cnt,
    tau=tau,
    fixed_softmax_g=fixed_softmax_g,

    texture_type=texture_type,
    texture_org_shape=texture_org_shape,

    mask_type=mask_type,
    mask_image=mask_image,
    mask_screen=mask_screen,

    # obj_file=obj_file,
    # renderer_config=renderer_config,

    detector_type=detector_type,
    detector_yolov9=detector_yolov9,
    # The detector_mmdet config is commented out in the parameter cell, so a
    # direct reference raised NameError here. Log None when it is undefined.
    detector_mmdet=globals().get('detector_mmdet'),

    data_config=data_config,

    target_class=target_class,
    # tester_config=tester_config,

    # epochs=epochs,
    # train_size=train_size,
    # val_size=val_size,

    # lr=lr,
    n_smooth=n_smooth,
)

In [None]:
# Dominant-colour extraction. The background's main colours are logged and
# later used as the palette of the softmax texture (`set_main_colors`).
from modules.img_utils import sum_images, get_dominant_colors, show_colors

imgs = Image.open(background)
main_colors = get_dominant_colors(imgs, main_color_cnt)
logger.pprint(main_colors=main_colors)

# Save a visualisation of the palette into the log directory and show it.
main_color_show = show_colors(main_colors)
logger.save_graph('main_colors.png', main_color_show)
main_color_show.show()

In [None]:
# Car model built from the OBJ file at `obj_file`, placed on `device`.
from modules.car_model import CarModel
car = CarModel(device, obj_file)

In [None]:
# Texture post-processing: ImageMask is configured by `mask_image`; ScreenMask
# is configured by `mask_screen` plus the car's UV shape.
from modules.car_model import ImageMask, ScreenMask
if mask_type == 'image':
    car_tex_mask = ImageMask(device, **mask_image)
elif mask_type == 'screen':
    car_tex_mask = ScreenMask(device, uv_shape=car.uv_shape, **mask_screen)
else:
    # A bare `raise` outside an except block only produced the confusing
    # "RuntimeError: No active exception to reraise".
    raise ValueError(f"unknown mask_type {mask_type!r}; expected 'image' or 'screen'")

In [None]:
# Renderer, plus a quick visual check: render the car with no texture
# (`to_mesh(None)`) at poses spanning the configured distance/elevation ranges.
from modules.renderer import Renderer
normal_render = Renderer(device, **renderer_config)

dis_min, dis_max = data_config['Distance']
sigma_min, sigma_max = data_config['Sigma']
# Pose tuples, presumably (distance, elevation, azimuth, roll) mirroring data_config
render_test_pos = [
    # far, low elevation: level, rolled -15, rolled +15
    (dis_max, sigma_min, 60, 0),
    (dis_max, sigma_min, 60, -15),
    (dis_max, sigma_min, 60, 15),
    # near, low elevation
    (dis_min, sigma_min, 60, 0),
    # near, high elevation, sweeping the azimuth
    (dis_min, sigma_max, 60, 0),
    (dis_min, sigma_max, 120, 0),
    (dis_min, sigma_max, 180, 0),
    (dis_min, sigma_max, -120, 0),
    (dis_min, sigma_max, -60, 0),
    (dis_min, sigma_max, 0, 0),
]

preview_tiles = []
for pose in render_test_pos:
    rendered = normal_render.render(car.to_mesh(None), pose)
    preview_tiles.append(tensor2img(rendered))
sum_images(preview_tiles).show()

In [None]:
# Detector under attack. Pass the keyword arguments belonging to the selected
# detector type; previously the YOLOv9 arguments were passed unconditionally.
from modules.detector import Detector
if detector_type == 'yolov9':
    detector_kwargs = detector_yolov9
elif detector_type == 'mmdet':
    # Requires the `detector_mmdet` dict in the parameter cell to be uncommented.
    detector_kwargs = detector_mmdet
else:
    raise ValueError(f"unknown detector_type {detector_type!r}; expected 'yolov9' or 'mmdet'")
net = Detector(device, img_size, detector_type, **detector_kwargs)

In [None]:
# Validation set: sampled once and reused by every evaluation, so results are
# comparable across epochs.
from modules.dataset_generator import train_data_sample, SimpleBackgroundDataset

val_data_set = train_data_sample(device, val_size, background, img_size, **data_config)
# Peek at the first three fields of the first sample.
for i in range(3):
    print(val_data_set[0][i])

test_dataset = SimpleBackgroundDataset(device, val_data_set, car_tex_mask, car, normal_render)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# tensor2img(test_dataset[0]).show()

In [None]:
# Attack-success-rate (ASR) tester over the fixed validation loader, plus
# (mostly disabled) baseline evaluations.
from modules.attack_tests import AttackTester
from modules.texture_optim import DirectTexture, GumbleSoftmaxTexture

texture_tester = AttackTester(
    test_dataloader,
    net,
    logger,
    target_class=target_class,
    img_size=img_size,
    **tester_config,
)

# Baseline: no adversarial texture (presumably the car's own texture)
test_dataset.set_texture(None)
# texture_tester.test('Origin')

# Further baselines, disabled by default:
# plain white
# test_dataset.set_texture(torch.tensor([[[1.,1.,1.]]], device=device))
# texture_tester.test('BL_white')

# random direct textures of several sizes
# for size in [128, 256, 512, 1024, 2048]:
#     test_dataset.set_texture(DirectTexture(device, [size, size]).texture())
#     texture_tester.test(f'BL_rd_{size}')

# the single most dominant colour
# test_dataset.set_texture(torch.tensor([[main_colors[0]]], device=device) / 255.0)
# texture_tester.test('BL_mc0')

# random softmax textures over the dominant-colour palette
# for size in [128, 256, 512, 1024, 2048]:
#     t=GumbleSoftmaxTexture(device, [size, size], len(main_colors), tau=0.1)
#     t.set_main_colors(main_colors)
#     test_dataset.set_texture(t.texture())
#     texture_tester.test(f'BL_mc_{size}')

In [None]:
# Texture generation: build the trainable texture and its Adam optimizer.
if texture_type == 'direct':
    texture = DirectTexture(device, texture_org_shape)
elif texture_type == 'softmax':
    texture = GumbleSoftmaxTexture(device, texture_org_shape, main_color_cnt, tau, fixed_softmax_g)
    # Use the background's dominant colours as the texture palette.
    texture.set_main_colors(main_colors)
else:
    # A bare `raise` outside an except block only produced the confusing
    # "RuntimeError: No active exception to reraise".
    raise ValueError(f"unknown texture_type {texture_type!r}; expected 'direct' or 'softmax'")

optim = torch.optim.Adam(texture.params(), lr=lr)  # Adam optimizer over the texture parameters

# Apply the initial texture to the validation set and preview a sample and the texture.
test_dataset.set_texture(texture.texture())
tensor2img(test_dataset[0]).show()
tensor2img(test_dataset.get_texture()).resize((400, 400)).show()

# Optionally evaluate the untrained texture (epoch 0).
if 0 in set(test_at):
    texture_tester.test('Epoch_0')

In [None]:
# Realtime texture server: receives a texture snapshot after every training
# step (files under the log's 'realtime/' path); a video is saved after training.
if realtime:
    from modules.tex_server import TexServer

    server = TexServer(logger.get_path('realtime/'), sample_per=50)
    initial_tex_img = tensor2img(texture.texture())
    server.new_tex(initial_tex_img)

In [None]:
# Training: optimise the texture to minimise the detector's class/objectness
# losses for `target_class`, regularised by a smoothness term.
from modules.img_utils import loss_smooth

test_at_set = set(test_at)  # hoisted out of the epoch loop

for i_epoch in range(epochs):
    # Fresh training samples every epoch.
    train_data_set = train_data_sample(device, train_size, background, img_size, **data_config)

    dataset = SimpleBackgroundDataset(device, train_data_set, car_tex_mask, car, normal_render)
    # Give the freshly built dataset the current texture before iterating.
    # Previously, set_texture on this dataset was first called after the
    # first batch had already been rendered. This is harmless if the texture
    # is already shared through `car`/`car_tex_mask` — TODO confirm.
    dataset.set_texture(texture.texture())
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

    pbar = tqdm(dataloader)

    for total_img in pbar:
        # Weighted sum of the loss terms.
        lcls, lobj = net.cpt_loss(total_img, target_class)
        l_smooth = loss_smooth(texture.texture())
        loss = 0.5 * lcls + lobj + n_smooth * l_smooth

        # Back-propagate. retain_graph=True is kept as in the original,
        # presumably because part of the autograd graph is reused — TODO confirm.
        optim.zero_grad()
        if loss > 0:
            loss.backward(retain_graph=True)
        optim.step()

        # Push the updated texture to the realtime server and the dataset.
        if realtime:
            server.new_tex(tensor2img(texture.texture()))
        dataset.set_texture(texture.texture())

        pbar.set_description('cls %.3f,obj %.3f,smo %.3f' %
                             (lcls.item(), lobj.item(), l_smooth.item()))

    # Evaluate on the fixed validation set at the requested epochs.
    if i_epoch + 1 in test_at_set:
        test_dataset.set_texture(texture.texture())
        texture_tester.test(f'Epoch_{i_epoch+1}')

# Save the video of texture snapshots.
if realtime:
    server.save_video()

# Save the final texture parameter (the prob map for the softmax texture).
final_param: torch.Tensor = texture.params()[0]
torch.save(final_param.detach().cpu(), logger.get_path('final.pth'))