In [1]:
import numpy as np

import torch
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import seaborn as sns

In [2]:
# Per-channel (RGB) mean and standard deviation used to normalise images before
# they are fed to the classifier. The numbers match the statistics commonly
# used for ImageNet-pretrained torchvision models.
mean_vals, std_vals = (
    np.array(channel_stats)
    for channel_stats in ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

# Name of the torchvision architecture fetched from torch.hub; it is also used
# further down to label the TensorBoard runs.
net_name = 'efficientnet_b0'

# `Module.eval()` returns the module itself, so the network can be loaded and
# switched to inference mode in a single expression.
pretrained = torch.hub.load("pytorch/vision", net_name, pretrained=True).eval()

# Cell output: shows that the network is in eval mode (expected value: False).
pretrained.training

Using cache found in /Users/grisha.oryol/.cache/torch/hub/pytorch_vision_main
  warn(f"Failed to load image Python extension: {e}")


False

In [3]:
def smoothing_kernel(kernel_size=5, rad=1):
    """Build a normalised, Gaussian-shaped 2-D smoothing kernel.

    The weight at position ``(i, j)`` is ``exp(-d**2 / rad**2)`` where ``d`` is
    the Euclidean distance (in pixels) from ``(i, j)`` to the centre index
    ``kernel_size // 2``. The weights are then scaled so that they sum to 1.

    Args:
        kernel_size: side length of the square kernel.
        rad: length scale of the fall-off; larger values give a flatter kernel.

    Returns:
        ``numpy.ndarray`` of shape ``(kernel_size, kernel_size)`` and dtype float64.

    Raises:
        ZeroDivisionError: if ``rad`` is zero.
    """
    if rad == 0:
        # Fail loudly instead of letting NumPy silently produce a NaN kernel.
        raise ZeroDivisionError("rad must be non-zero")

    # Signed distance of every row / column index from the kernel centre.
    offsets = np.arange(kernel_size) - kernel_size // 2
    # Broadcasting a column against a row yields the full grid of squared distances.
    squared_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
    kernel = np.exp(-squared_dist / rad ** 2)
    return kernel / kernel.sum()

    
class PreSmoother(torch.nn.Module):
    """Blur every channel of an image batch with one fixed, non-trainable kernel.

    The kernel comes from ``smoothing_kernel(kernel_size, rad)`` and is installed
    as the weight of a single-channel ``Conv2d`` without bias and without padding,
    so each spatial dimension of the output is ``kernel_size - 1`` smaller than
    the corresponding input dimension.

    Args:
        kernel_size: side length of the square smoothing kernel.
        rad: length scale forwarded to ``smoothing_kernel``.
    """

    def __init__(self, kernel_size=5, rad=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.smoother = torch.nn.Conv2d(1, 1, kernel_size, bias=False, padding=0)
        # NumPy copy of the kernel, kept for inspection.
        self.kernel = smoothing_kernel(kernel_size, rad)
        # Shape (1, 1, k, k): one input channel, one output channel.
        weight = torch.tensor(self.kernel).unsqueeze(0).unsqueeze(0).float()
        # requires_grad=False keeps the kernel fixed during optimisation.
        self.smoother.weight = torch.nn.Parameter(weight, requires_grad=False)

    def forward(self, x):
        """Smooth ``x`` channel by channel.

        Args:
            x: tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape
            ``(batch, channels, height - kernel_size + 1, width - kernel_size + 1)``.
        """
        batch, channels, height, width = x.shape
        # Treat every channel of every sample as an independent one-channel image
        # so the same kernel is applied to all of them.
        smoothed = self.smoother(x.reshape(batch * channels, 1, height, width))
        # Restore the original batch/channel layout. For batch == 1 this is the
        # same as the former reshape to (1, -1, H, W); for larger batches the
        # batch dimension is no longer folded into the channel dimension.
        return smoothed.reshape(batch, channels, smoothed.shape[2], smoothed.shape[3])


class NewNet(pl.LightningModule):
    """Optimise an input image so that a frozen classifier assigns it to one class.

    The only tensor handed to the optimiser is ``self.inp``. On every step the
    image is blurred by a fixed ``PreSmoother``, normalised with the module-level
    ``mean_vals`` / ``std_vals`` and passed through the module-level ``pretrained``
    network, whose parameters are frozen. The loss is a weighted squared error
    between the softmax output and a one-hot target plus a penalty on pixel
    values outside ``[0, 1]``.

    Args:
        class_idx: index of the class the image should be classified as.
        inp_size: side length of the smoothed image given to the classifier. The
            raw image is ``kernel_size - 1`` pixels larger because the smoother
            does not pad.
        inp: optional initial image, wrapped in a ``Parameter`` as is. The rest of
            the class indexes it as ``inp[0]`` and normalises along dim 1, so a
            shape of ``(1, 3, H, W)`` is presumably expected -- confirm with callers.
            When ``None`` a random smooth image is generated.
        min_inp, max_inp: stored on the instance; not used by the loss in this class.
        init_ampl: multiplier applied to the random initial image.
        greyscale: if true, the same random pattern is used for all three channels.
        kernel_size, rad: forwarded to ``PreSmoother``.
        num_classes: length of the one-hot target vector.
        period: images are logged every ``period`` steps and the loss is printed
            every ``100 * period`` steps.
        lr: learning rate of the SGD optimiser.
        max_freq: forwarded to ``get_random_input``.
        corr_class_mult: extra weight of the target-class term in the loss.
    """

    def __init__(
                self, class_idx, inp_size=224, inp=None, min_inp=-2, max_inp=2, init_ampl=0.1, greyscale=True,
                kernel_size=5,
                rad=1,
                num_classes=1000,
                period=100,
                lr=0.00001,
                max_freq=2,
                corr_class_mult=1,
    ):
        # Initialise the Module machinery before any attribute is assigned.
        super().__init__()
        self.min_inp = min_inp
        self.max_inp = max_inp
        self.inp_size = inp_size
        # The un-padded convolution removes kernel_size - 1 pixels per dimension.
        self.inp_size_raw = inp_size + kernel_size - 1

        if inp is not None:
            self.inp = torch.nn.Parameter(inp)
        else:
            if greyscale:
                one_channel = get_random_input(self.inp_size_raw, max_freq=max_freq)
                # Three independent copies of the same pattern, shape (3, H, W).
                channels = one_channel.unsqueeze(0).repeat(3, 1, 1)
            else:
                channels = torch.stack(
                    [get_random_input(self.inp_size_raw, max_freq=max_freq) for _ in range(3)]
                )
            self.inp = torch.nn.Parameter(init_ampl * channels.unsqueeze(0))

        # Classifier loaded at notebook level; it is only evaluated, never trained.
        self.pretrained = pretrained
        self.pretrained.eval()
        for param in self.pretrained.parameters():
            param.requires_grad = False

        # One-hot target of shape (1, num_classes).
        self.target = torch.zeros(num_classes).unsqueeze(0)
        self.target[0, class_idx] = 1
        self.class_idx = class_idx

        self.pre_smoother = PreSmoother(kernel_size, rad)
        # Snapshots of the starting point, used to visualise how far the image moved.
        self.init_inp = self.inp.detach()[0].clone()
        self.init_smoothed = self.pre_smoother(self.inp).detach()[0].clone()

        self.period = period
        self.lr = lr
        self.corr_class_mult = corr_class_mult

        self.mean_vals = torch.tensor(mean_vals).float()
        self.std_vals = torch.tensor(std_vals).float()

        # `hparams` is a read-only property in newer Lightning releases, so the
        # mapping it returns is filled in place instead of being replaced.
        self.hparams.update(dict(class_idx=class_idx,
                                 inp_size=inp_size,
                                 min_inp=min_inp,
                                 max_inp=max_inp,
                                 init_ampl=init_ampl,
                                 greyscale=greyscale,
                                 kernel_size=kernel_size,
                                 rad=rad,
                                 num_classes=num_classes,
                                 lr=lr,
                                 max_freq=max_freq,
                                 corr_class_mult=corr_class_mult,
                                 ))

        self.example_input_array = []

    def on_train_start(self):
        """Add the trainer's gradient clipping value to the hyper-parameters and log them."""
        self.hparams["gradient_clip_val"] = self.trainer.gradient_clip_val
        print(self.hparams)
        self.logger.log_hyperparams(self.hparams)

    def forward(self):
        """Return the class probabilities of the current image, shape ``(1, num_classes)``.

        Also stores the intermediate tensors as ``self.smoothed``, ``self.rescaled``
        and ``self.probs`` so that the logging code can read them.
        """
        self.pre_smoother.eval()
        self.smoothed = self.pre_smoother(self.inp)

        # Per-channel normalisation; the statistics are broadcast over dim 1.
        mean = self.mean_vals.view(1, -1, 1, 1)
        std = self.std_vals.view(1, -1, 1, 1)
        self.rescaled = (self.smoothed - mean) / std

        # The classifier is a sub-module, so it must be put back into eval mode
        # whenever the surrounding module has been switched to train mode.
        self.pretrained.eval()
        for param in self.pretrained.parameters():
            param.requires_grad = False
        self.probs = torch.softmax(self.pretrained(self.rescaled), dim=-1)
        return self.probs

    def training_step(self, batch, batch_nb):
        """Compute the loss for the current image and write diagnostics to TensorBoard.

        ``batch`` and ``batch_nb`` are ignored: the optimised image is the only input.
        Training is stopped once the unweighted squared error drops below 1e-4.
        """
        probs = self.forward()
        step = self.trainer.global_step

        # Penalise pixel values below 0 or above 1 (min_inp / max_inp are not used here).
        reg = torch.nn.functional.relu(- self.inp).sum() + torch.nn.functional.relu(self.inp - 1).sum()
        # The target class gets weight 1 + corr_class_mult, all other classes weight 1.
        mults = 1 + self.corr_class_mult * self.target
        squared_err = (probs - self.target) ** 2
        prob_mse = (mults * squared_err).sum()
        loss = prob_mse + reg

        prob_mse_for_stopping = squared_err.sum()
        self.log("z_loss_for_stopping", prob_mse_for_stopping, on_step=True)
        if prob_mse_for_stopping < 0.0001:
            self.trainer.should_stop = True

        self.tensorboard = self.logger.experiment
        if step % (100 * self.period) == 0:
            print(loss.item())

        logged_scalars = {
                "prob_mse": prob_mse,
                "reg": reg,
                "total_loss": loss
        }
        self.tensorboard.add_scalars("0_losses", logged_scalars, global_step=step)
        self.tensorboard.add_scalar("0_prob_corr_class", probs[0][self.class_idx], global_step=step)

        # Root-mean-square distance from the initial image. `** 0.5` is required:
        # `** 1/2` parses as `(x ** 1) / 2` and would merely halve the mean.
        rms_diff = ((self.init_inp - self.inp[0]) ** 2).mean() ** 0.5
        self.tensorboard.add_scalar("mean_diff_init", rms_diff, global_step=step)

        if step % self.period == 0:
            self._log_images(probs, step)

        return loss

    def _log_images(self, probs, step):
        """Log the image, its smoothed version, their change since the start and a probability plot."""
        board = self.logger.experiment
        inp_diff = (self.init_inp - self.inp[0]).abs()
        smoothed_diff = (self.init_smoothed - self.smoothed[0]).abs()

        board.add_image("1_inp", self.inp[0], global_step=step)
        board.add_image("1_smoothed", self.smoothed[0], global_step=step)
        board.add_image("2_inp_diff", normalize_for_logging(inp_diff.sum(axis=0)),
                        global_step=step, dataformats="HW")
        board.add_image("2_inp_diff_color", normalize_for_logging(inp_diff, eps=0),
                        global_step=step, dataformats="CHW")
        board.add_image("2_smoothed_diff", normalize_for_logging(smoothed_diff.sum(axis=0), eps=0),
                        global_step=step, dataformats="HW")
        board.add_image("2_smoothed_diff_color", normalize_for_logging(smoothed_diff),
                        global_step=step, dataformats="CHW")
        board.add_image("3_probs_plots", self._render_probs_plot(probs),
                        global_step=step, dataformats="HWC")

    def _render_probs_plot(self, probs):
        """Render target and predicted probabilities as an RGB uint8 array of shape (H, W, 3)."""
        fig = Figure()
        canvas = FigureCanvas(fig)
        ax = fig.gca()
        ax.plot(self.target[0].detach().numpy())
        ax.plot(probs[0].detach().numpy())
        ax.axis('off')
        canvas.draw()       # draw the canvas, cache the renderer
        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib no longer
        # provides; the alpha channel is dropped and the data copied out of the canvas.
        return np.asarray(canvas.buffer_rgba())[..., :3].copy()

    def on_after_backward(self):
        """Log gradient statistics of the optimised image after each backward pass."""
        # Always inspect this instance's parameter, not a notebook-level variable.
        grad = self.inp.grad
        if grad is None:
            return

        step = self.trainer.global_step
        grad_img = grad[0].abs()

        self.log("9_max_grad_orig", grad_img.max(), on_step=True)
        self.log("9_mean_grad_orig", grad_img.mean(), on_step=True)

        if step % self.period == 0:
            board = self.logger.experiment
            max_grad = grad_img.max()
            # add_image expects float images in [0, 1] and scales them to 0..255
            # itself, so the gradient is only divided by its maximum. An all-zero
            # gradient is logged unchanged to avoid 0 / 0.
            grad_img_norm = grad_img / max_grad if max_grad > 0 else grad_img
            board.add_image("9_grad_img", grad_img_norm,
                            global_step=step, dataformats="CHW")
            grads_grayscale = normalize_for_logging((grad[0] ** 2).sum(axis=0))
            board.add_image("9_grad_grayscale", grads_grayscale,
                            global_step=step, dataformats="HW")
            board.add_histogram("9_grad_hist", grad, global_step=step)

    def configure_optimizers(self):
        """Return a plain SGD optimiser that updates only this instance's image."""
        optimizer = torch.optim.SGD([self.inp], lr=self.lr)
        return optimizer
    
    
def normalize_for_logging(t, eps=0.001):
    """Rescale a tensor for display so that its bulk falls into ``[0, 1]``.

    The ``eps`` and ``1 - eps`` quantiles of all elements are mapped to 0 and 1
    respectively. Elements outside that quantile range end up below 0 or above 1.

    Args:
        t: floating-point tensor of any shape.
        eps: quantile used as lower bound (and ``1 - eps`` as upper bound);
            ``eps=0`` gives plain min/max scaling.

    Returns:
        Tensor of the same shape as ``t``. If the two quantiles coincide (for
        example when ``t`` is constant) a tensor of zeros is returned instead of
        dividing by zero.
    """
    flat = t.flatten()
    min_val = flat.quantile(eps)
    max_val = flat.quantile(1 - eps)
    span = max_val - min_val
    if span <= 0:
        # Degenerate range: there is nothing to stretch, and 0 / 0 would yield NaN.
        return torch.zeros_like(t)
    return (t - min_val) / span
    
    
class DummyDataset(torch.utils.data.IterableDataset):
    """Placeholder dataset that yields the integers ``0 .. num_steps - 1``.

    It exists only to make a training loop run for a fixed number of steps; the
    yielded values carry no information.

    Args:
        num_steps: number of items produced per pass over the dataset.
    """

    def __init__(self, num_steps):
        # `super(DummyDataset)` without `self` creates an unbound super object
        # and never runs the parent initialiser; the zero-argument form does.
        super().__init__()
        self.num_steps = num_steps

    def __iter__(self):
        """Return a fresh iterator over ``range(num_steps)``."""
        return iter(range(self.num_steps))

    
def get_random_input(inp_size, max_freq=5):
    """
    Generate a random smooth 2-D pattern by summing random harmonics.

    One harmonic is drawn with ``get_harmonic`` for every pair ``(Tx, Ty)`` with
    ``0 <= Tx, Ty < max_freq``; the sum is then rescaled linearly so that its
    minimum is 0 and its maximum is 1.

    Args:
        inp_size: side length of the square pattern.
        max_freq: number of frequencies used along each axis (must be >= 1).

    Returns:
        Tensor of shape ``(inp_size, inp_size)``. If the summed pattern is
        constant (for example with ``max_freq=1``) a tensor of zeros is returned
        instead of dividing by zero.

    Raises:
        ValueError: if ``max_freq`` is smaller than 1, i.e. there is nothing to sum.
    """
    if max_freq < 1:
        raise ValueError(f"max_freq must be at least 1, got {max_freq}")

    # Tx is the outer loop and Ty the inner one, so random numbers are consumed
    # in the same order as before and seeded runs stay reproducible.
    inp = sum(get_harmonic(Tx, Ty, inp_size) for Tx in range(max_freq) for Ty in range(max_freq))

    low = inp.min()
    span = inp.max() - low
    if span == 0:
        # Constant pattern: rescaling would compute 0 / 0 and return NaN.
        return torch.zeros_like(inp)
    return (inp - low) / span


def get_harmonic(Tx, Ty, inp_size):
    """
    Draw one random 2-D harmonic of shape ``(inp_size, inp_size)``.

    The result is ``amplitude * sin(x) * sin(y)``: ``x`` runs from 0 to
    ``pi * Tx`` along the columns and ``y`` from 0 to ``pi * Ty`` along the rows,
    each shifted by its own random phase in ``[0, 2*pi)``; the amplitude is drawn
    uniformly from ``[0, 1)``.
    """
    # The three random draws happen in this order: x phase, y phase, amplitude.
    phase_x = 2 * np.pi * np.random.rand()
    row_of_x = torch.linspace(0, np.pi * Tx, inp_size).reshape(1, -1) + phase_x

    phase_y = 2 * np.pi * np.random.rand()
    column_of_y = torch.linspace(0, np.pi * Ty, inp_size).reshape(-1, 1) + phase_y

    amplitude = np.random.rand()
    # (1, n) times (n, 1) broadcasts to the full (n, n) grid.
    return amplitude * torch.sin(row_of_x) * torch.sin(column_of_y)

In [None]:
# Sweep over initial amplitude, target class and smoothing kernel size. Every
# configuration is optimised twice, each time from a fresh random initial image.
for init_ampl in [0.9, 0.1, 0.01]:
    for class_idx in range(0, 1000, 100):
        for kernel_size in [5, 9, 19]:
            for _ in range(2):
                logger = TensorBoardLogger("logs_efficientnet_b0", name=f"{net_name}-{class_idx}", default_hp_metric=False, log_graph=True)
                trainer = pl.Trainer(logger=logger, max_epochs=1, gradient_clip_val=0.1)

                model = NewNet(class_idx=class_idx, period=10, init_ampl=init_ampl, greyscale=True, lr=1e1, kernel_size=kernel_size, rad=kernel_size/2, corr_class_mult=10)
                # The dataset only determines the maximum number of optimisation steps.
                train_dataloader = torch.utils.data.DataLoader(DummyDataset(20000))
                # The dataloader is passed positionally: the keyword was renamed from
                # `train_dataloader` to `train_dataloaders` in newer Lightning
                # releases, while the positional form works with both.
                trainer.fit(model, train_dataloader)