In [1]:
import os
import pickle
import time

import scipy.io as sio
import torch
import torch.nn as nn
from torchinfo import summary

import datasets
import plots
import transformer
import utils

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
class AutoEncoder(nn.Module):
    """Convolutional encoder + ViT autoencoder that outputs softmax maps and a reconstruction.

    Data flow (see ``forward``): a stack of 1x1 convolutions compresses the
    ``L`` input channels, ``transformer.ViT`` pools the result into an
    embedding, the embedding is split into ``P`` vectors of width ``dim``, each
    vector is linearly expanded to a ``size x size`` map, and a 3x3 convolution
    followed by a channel-wise softmax produces the final ``P`` maps.  A
    bias-free 1x1 convolution plus ReLU then maps those ``P`` channels back to
    ``L`` channels.

    Args:
        P: Number of endmembers; also the channel count of the softmax maps.
        L: Number of channels of the input image (and of the reconstruction).
            Presumably the number of spectral bands -- TODO confirm.
        size: Side length of the square maps produced by ``upscale``; also
            passed to the ViT as ``image_size``.
        patch: Patch size handed to the ViT.  The encoder emits
            ``(dim * P) // patch ** 2`` channels.
        dim: Embedding width per endmember; the ViT is built with
            ``dim = dim * P``.
    """

    def __init__(self, P, L, size, patch, dim):
        super().__init__()
        self.P, self.L, self.size, self.dim = P, L, size, dim

        # Channel count fed to the ViT (integer division, as in the original).
        embed_channels = (dim * P) // patch ** 2

        self.encoder = nn.Sequential(
            nn.Conv2d(L, 128, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.Dropout(0.25),
            nn.LeakyReLU(),
            nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.LeakyReLU(),
            nn.Conv2d(64, embed_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(embed_channels, momentum=0.5),
        )

        self.vtrans = transformer.ViT(image_size=size, patch_size=patch, dim=(dim * P), depth=2,
                                      heads=8, mlp_dim=12, pool='cls')

        # Expands each `dim`-wide endmember embedding to a flattened size*size map.
        self.upscale = nn.Sequential(
            nn.Linear(dim, size ** 2),
        )

        # 3x3 smoothing, then softmax across the P channels so they sum to 1 per pixel.
        self.smooth = nn.Sequential(
            nn.Conv2d(P, P, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.Softmax(dim=1),
        )

        # Bias-free 1x1 convolution from P channels back to L channels.
        self.decoder = nn.Sequential(
            nn.Conv2d(P, L, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.ReLU(),
        )

    @staticmethod
    def weights_init(m):
        """Apply Kaiming-normal initialisation to the weight of an ``nn.Conv2d`` module.

        Modules of any other type are left untouched, so this can be passed to
        ``nn.Module.apply``.
        """
        if type(m) is nn.Conv2d:
            nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Run the autoencoder.

        Args:
            x: Input tensor fed to the 1x1-conv encoder; its channel dimension
                must equal ``L``.

        Returns:
            Tuple ``(abu_est, re_result)``: ``abu_est`` has shape
            ``(B, P, size, size)`` and is a softmax over dim 1; ``re_result``
            is the decoder output with ``L`` channels.
        """
        abu_est = self.encoder(x)
        cls_emb = self.vtrans(abu_est)
        # Split the pooled embedding into one `dim`-wide vector per endmember
        # (P is the number of endmembers).  The batch axis is inferred instead of
        # being hard-coded to 1; assumes the ViT yields P*dim values per sample
        # -- TODO confirm against transformer.ViT.
        cls_emb = cls_emb.view(-1, self.P, self.dim)
        abu_est = self.upscale(cls_emb).view(-1, self.P, self.size, self.size)
        abu_est = self.smooth(abu_est)
        re_result = self.decoder(abu_est)
        return abu_est, re_result

In [12]:
# Build the network and print a per-layer summary for a single input image.
# The summary's input shape is read back from the model so the channel count
# and spatial size are stated only once.
batch_size = 1
model = AutoEncoder(P=3, L=156, size=95, patch=5, dim=200)
summary(model, input_size=(batch_size, model.L, model.size, model.size))

Layer (type:depth-idx)                                  Output Shape              Param #
AutoEncoder                                             [1, 3, 95, 95]            --
├─Sequential: 1-1                                       [1, 24, 95, 95]           --
│    └─Conv2d: 2-1                                      [1, 128, 95, 95]          20,096
│    └─BatchNorm2d: 2-2                                 [1, 128, 95, 95]          256
│    └─Dropout: 2-3                                     [1, 128, 95, 95]          --
│    └─LeakyReLU: 2-4                                   [1, 128, 95, 95]          --
│    └─Conv2d: 2-5                                      [1, 64, 95, 95]           8,256
│    └─BatchNorm2d: 2-6                                 [1, 64, 95, 95]           128
│    └─LeakyReLU: 2-7                                   [1, 64, 95, 95]           --
│    └─Conv2d: 2-8                                      [1, 24, 95, 95]           1,560
│    └─BatchNorm2d: 2-9                         

In [None]:
# NOTE(review): this cell redefines AutoEncoder and shadows the definition in
# an earlier cell of this notebook; keep only one of the two.
class AutoEncoder(nn.Module):
    """Convolutional encoder + ViT autoencoder that outputs softmax maps and a reconstruction.

    Data flow (see ``forward``): a stack of 1x1 convolutions compresses the
    ``L`` input channels, ``transformer.ViT`` pools the result into an
    embedding, the embedding is split into ``P`` vectors of width ``dim``, each
    vector is linearly expanded to a ``size x size`` map, and a 3x3 convolution
    followed by a channel-wise softmax produces the final ``P`` maps.  A
    bias-free 1x1 convolution plus ReLU then maps those ``P`` channels back to
    ``L`` channels.

    Args:
        P: Number of endmembers; also the channel count of the softmax maps.
        L: Number of channels of the input image (and of the reconstruction).
            Presumably the number of spectral bands -- TODO confirm.
        size: Side length of the square maps produced by ``upscale``; also
            passed to the ViT as ``image_size``.
        patch: Patch size handed to the ViT.  The encoder emits
            ``(dim * P) // patch ** 2`` channels.
        dim: Embedding width per endmember; the ViT is built with
            ``dim = dim * P``.
    """

    def __init__(self, P, L, size, patch, dim):
        super().__init__()
        self.P, self.L, self.size, self.dim = P, L, size, dim

        # Channel count fed to the ViT (integer division, as in the original).
        embed_channels = (dim * P) // patch ** 2

        self.encoder = nn.Sequential(
            nn.Conv2d(L, 128, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.Dropout(0.25),
            nn.LeakyReLU(),
            nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.LeakyReLU(),
            nn.Conv2d(64, embed_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(embed_channels, momentum=0.5),
        )

        self.vtrans = transformer.ViT(image_size=size, patch_size=patch, dim=(dim * P), depth=2,
                                      heads=8, mlp_dim=12, pool='cls')

        # Expands each `dim`-wide endmember embedding to a flattened size*size map.
        self.upscale = nn.Sequential(
            nn.Linear(dim, size ** 2),
        )

        # 3x3 smoothing, then softmax across the P channels so they sum to 1 per pixel.
        self.smooth = nn.Sequential(
            nn.Conv2d(P, P, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.Softmax(dim=1),
        )

        # Bias-free 1x1 convolution from P channels back to L channels.
        self.decoder = nn.Sequential(
            nn.Conv2d(P, L, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.ReLU(),
        )

    @staticmethod
    def weights_init(m):
        """Apply Kaiming-normal initialisation to the weight of an ``nn.Conv2d`` module.

        Modules of any other type are left untouched, so this can be passed to
        ``nn.Module.apply``.
        """
        if type(m) is nn.Conv2d:
            nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Run the autoencoder.

        Args:
            x: Input tensor fed to the 1x1-conv encoder; its channel dimension
                must equal ``L``.

        Returns:
            Tuple ``(abu_est, re_result)``: ``abu_est`` has shape
            ``(B, P, size, size)`` and is a softmax over dim 1; ``re_result``
            is the decoder output with ``L`` channels.
        """
        abu_est = self.encoder(x)
        cls_emb = self.vtrans(abu_est)
        # Split the pooled embedding into one `dim`-wide vector per endmember
        # (P is the number of endmembers).  The batch axis is inferred instead of
        # being hard-coded to 1; assumes the ViT yields P*dim values per sample
        # -- TODO confirm against transformer.ViT.
        cls_emb = cls_emb.view(-1, self.P, self.dim)
        abu_est = self.upscale(cls_emb).view(-1, self.P, self.size, self.size)
        abu_est = self.smooth(abu_est)
        re_result = self.decoder(abu_est)
        return abu_est, re_result