In [1]:
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# Converts each image into a PyTorch tensor.
tensor_transform = transforms.ToTensor()

# MNIST training split, stored under ./data (download requested).
dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=tensor_transform,
)

# Shuffled mini-batch iterator over the dataset, used for training.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1024,
    shuffle=True,
)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda', index=0)

In [3]:
class FFHead(torch.nn.Module):
    """Affine layer ``x @ W + b`` whose parameters are fitted in closed form.

    Instead of back-propagation, :meth:`update` solves a least-squares
    problem for the affine map that sends a new input to a previously
    observed output, then moves ``W`` and ``b`` a fraction ``lr`` of the
    way towards that solution.

    Args:
        input_size: Number of input features (rows of ``W``).
        output_size: Number of output features (columns of ``W``, length of ``b``).
        lr: Blend factor in ``[0, 1]`` used by :meth:`update`; ``0`` leaves the
            parameters untouched, ``1`` replaces them with the fitted target.
    """

    def __init__(self, input_size, output_size, lr=1e-4):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.lr = lr

        # Registered as buffers (not plain attributes) so that `.to(device)`
        # and `state_dict()` include them; they are not trained by autograd.
        self.register_buffer("W", torch.randn(input_size, output_size))
        self.register_buffer("b", torch.randn(output_size))

    def update(self, new_img, past_img_out):
        """Pull ``W``/``b`` towards the affine map ``new_img -> past_img_out``.

        Args:
            new_img: Input batch of shape ``(batch, input_size)``.
            past_img_out: Target batch of shape ``(batch, output_size)``.

        Returns:
            ``self.forward(new_img)`` evaluated with the updated parameters.
        """
        # Author's rationale (translated):
        # Since this is the first layer, we have access to the true label y and
        # the prediction y_hat produced by the previous step.
        # If y_hat and y differ a lot, the previous step was unreasonable and
        # needs to be adjusted.
        # By the autoencoder principle, the ideal l(F(x)) should be equivalent
        # to the linear model y = x.
        # At the ideal point: l(F(past_img_out)) = l(new_img) = past_img_out
        # Since l is a linear layer, the downstream model can also be treated
        # as a linear layer; call it F.
        # new_img = (img @ W + b) @ F + b'
        # From this we can derive the W and b to fit:
        #     past_img_out = new_img @ w_target + b_target

        # past_img_out = img @ self.W + self.b
        # w_target and b_target cannot be solved for simultaneously, so the
        # current bias is held fixed while solving for the weight first.
        b_target = self.b
        # new_img is (batch, input_size) and generally not square, so use the
        # Moore-Penrose pseudo-inverse (least-squares solution) rather than
        # torch.inverse. Rearranging the target equation gives
        # new_img @ w_target = past_img_out - b_target.
        w_target = torch.linalg.pinv(new_img) @ (past_img_out - b_target)
        # The residual is per-sample; average over the batch so the bias keeps
        # its (output_size,) shape.
        b_target = (past_img_out - new_img @ w_target).mean(dim=0)

        # Exponential-moving-average step towards the fitted parameters.
        self.W = w_target * self.lr + self.W * (1 - self.lr)
        self.b = b_target * self.lr + self.b * (1 - self.lr)

        return self.forward(new_img)

    def forward(self, x):
        """Apply the affine map: returns ``x @ W + b``."""
        x = x @ self.W
        x = x + self.b
        return x

In [None]:
import tqdm
import numpy as np


class FFAE(torch.nn.Module):
    """Autoencoder assembled from a stack of ``FFHead`` units (784 -> 9 -> 784).

    The units are updated one at a time through their own ``update`` methods
    rather than by back-propagating a global loss.

    NOTE(review): this class requires an ``FFHead`` that accepts an
    ``activation=`` keyword, exposes ``init_parameters(x) -> (loss, x)`` and
    whose ``update`` returns ``(loss, x)``. Confirm that the ``FFHead``
    definition in scope provides this interface before running the cell.

    Args:
        eps: Clamp value passed to ``torch.logit`` in :meth:`update`.
        init_dataloader: Optional iterable of ``(image, label)`` batches. When
            given, every unit's ``init_parameters`` is run on each batch.
    """

    def __init__(self, eps=1e-2, init_dataloader=None):
        super().__init__()
        self.eps = eps
        self.units = torch.nn.Sequential(
            FFHead(28*28, 128, activation=torch.nn.LeakyReLU()),
            FFHead(128, 64,  activation=torch.nn.LeakyReLU()),
            FFHead(64, 36,  activation=torch.nn.LeakyReLU()),
            FFHead(36, 18,  activation=torch.nn.LeakyReLU()),
            FFHead(18, 9,),
            FFHead(9, 18,),
            FFHead(18, 36,  activation=torch.nn.LeakyReLU()),
            FFHead(36, 64,  activation=torch.nn.LeakyReLU()),
            FFHead(64, 128,  activation=torch.nn.LeakyReLU()),
            FFHead(128, 28*28, activation=torch.nn.Sigmoid()),
        )
        self.loss = torch.nn.MSELoss()

        if init_dataloader is not None:
            # Iterate the loader that was passed in rather than the
            # notebook-global `loader`, so the argument is actually honoured.
            for image, _ in (pbar := tqdm.tqdm(init_dataloader)):
                # Reshaping the image to (-1, 784)
                x = image.reshape(-1, 28*28).to(device)
                loss = 0
                # Each unit initialises itself from the previous unit's output.
                for unit in self.units:
                    loss, x = unit.init_parameters(x)
                # Reports the loss returned by the last unit.
                pbar.set_description(f"Loss: {loss.item():.4f}")

    def forward(self, x):
        """Run ``x`` through every unit in order, with autograd disabled."""
        with torch.no_grad():
            for unit in self.units:
                x = unit(x)
        return x

    def update(self, img, times=10):
        """Run ``times`` layer-wise update sweeps on one batch.

        Args:
            img: Input batch; also used as the reconstruction target.
            times: Number of sweeps over all units.

        Returns:
            Mean, over the sweeps, of the loss returned by the last unit.
        """
        x = img.clone().detach()
        # buffer[i] holds the most recent output of units[i] for every unit
        # except the last one.
        buffer = []
        with torch.no_grad():
            for unit in self.units[:-1]:
                x = unit(x)
                buffer.append(x)

        # NOTE(review): the target here is logit(img) while the in-loop call
        # below passes img directly — confirm which one is intended.
        _, x = self.units[-1].update(x, torch.logit(img, self.eps))

        all_loss = []
        for t in range(times):
            # The first unit is updated with the current reconstruction as
            # input and its own previous output as target.
            _ = self.units[0].update(x, buffer.pop(0))
            x = self.units[0](img)
            buffer.append(x)
            # Each popped entry is re-appended in the same order, so the
            # buffer stays aligned with the units across sweeps.
            for unit in self.units[1:-1]:
                _, x = unit.update(x, buffer.pop(0))
                buffer.append(x)

            loss, x = self.units[-1].update(x, img)

            all_loss.append(loss.detach().cpu().numpy())
        return np.mean(all_loss)
