In [58]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

from liptrf.models.vit import L2Attention, ViT
from liptrf.models.layers import trunc, l2_normalize

In [59]:
model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=200)

In [60]:
sum(p.numel() for p in model.parameters())

5563016

In [61]:
model.patch_embed.proj.weight.shape

torch.Size([192, 3, 16, 16])

In [68]:
class PatchEmbedX(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, iter=5, lmbda=2.5, relax=1, lr=1, eta=1e-7):
        super(PatchEmbedX, self).__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.iter = iter 
        self.lmbda = lmbda
        self.relax = relax 
        self.lr = lr 
        self.eta = eta
        
        self.weight = nn.Parameter(torch.Tensor(embed_dim, in_chans, patch_size[0], patch_size[0]))
        self.bias = nn.Parameter(torch.Tensor(embed_dim))
        self.rand_x = nn.Parameter(trunc([1, in_chans, patch_size[0], patch_size[0]]), requires_grad=False)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        assert W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        x = F.conv2d(x, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

    def lipschitz(self):
        for i in range(self.iter):
            x = l2_normalize(self.rand_x)
            x_p = F.conv2d(x, self.weight, stride=self.patch_size)
            self.rand_x = nn.Parameter(F.conv_transpose2d(x_p, self.weight, stride=self.patch_size), requires_grad=False)
        
        Wx = F.conv2d(self.rand_x, self.weight, stride=self.patch_size)
        self.lc = torch.sqrt(torch.sum(Wx**2) / (torch.sum(x**2) + 1e-9)).data.cpu()
        del x, x_p, Wx
        torch.cuda.empty_cache()
        return self.lc

    def apply_spec(self):
        fc = self.weight.clone().detach()
        # print (fc.max())
        fc = fc * 1 / (max(1, self.lc / self.lmbda))
        # print (fc.max(), self.lc, self.lmbda)
        self.weight = nn.Parameter(fc)
        del fc
        torch.cuda.empty_cache()

    def prox(self):
        self.lipschitz()
        self.lmbda = self.relax
        self.apply_spec()
        self.prox_weight = self.weight.clone() #/ self.relax
        self.proj_weight = 2 * self.prox_weight - self.weight.clone()
        self.proj_weight_n = self.proj_weight.clone()

    def proj(self):
        # if torch.norm()
        if torch.norm(self.proj_weight_n-self.proj_weight, 'fro') < self.eta * torch.norm(self.weight, 'fro'):
            return 

        z = F.linear(self.inp, self.proj_weight_n) - self.out
        if len(z.shape) == 3:
            cjn = torch.mean(torch.sum(z**2, dim=[0, 1]) - self.eta)
        else:
            cjn = torch.mean(torch.sum(z**2, dim=0) - self.eta)

        del_wn = torch.zeros(self.proj_weight_n.shape)
        if cjn > 0:
            if len(self.inp.shape) == 3:
                num = 2 * torch.sum(torch.einsum("bnjd,bnci->bndc", 
                                    z.unsqueeze(-2), 
                                    self.inp.unsqueeze(-1)), dim=[0, 1])
            else:
                num = 2 * torch.sum(torch.einsum("bjd,bci->bdc", 
                                    z.unsqueeze(-2), 
                                    self.inp.unsqueeze(-1)), dim=0)
            num = num / self.out.shape[-1]
            den = torch.norm(num, 'fro')**2
            del_wn = -cjn * num / den 
        
        L = torch.sum(del_wn**2)
        if L > 1e-22:
            cW = self.proj_weight - self.proj_weight_n

            pi_n =  -1 * (cW.T.flatten().unsqueeze(0) @ del_wn.flatten().unsqueeze(1))
            mu_n = torch.norm(cW, p=2)**2
            vu_n = torch.norm(del_wn, p=2)**2 
            chi_n = mu_n * vu_n - pi_n**2 

            if chi_n < 0:
                chi_n = 0

            # print (del_wn.max(), vu_n, chi_n, pi_n, mu_n)
            if (chi_n == 0) and (pi_n >= 0):
                self.proj_weight_n = self.proj_weight_n + del_wn
            elif (chi_n > 0) and ((pi_n * vu_n) >= chi_n):
                self.proj_weight_n = self.proj_weight + (1  + pi_n/vu_n) * del_wn
            elif (chi_n > 0) and ((pi_n * vu_n) < chi_n):
                self.proj_weight_n = self.proj_weight_n + vu_n / chi_n * (pi_n * cW - mu_n * del_wn)
            else:
                raise Exception("Error")

    def update(self):
        self.proj_weight = self.proj_weight_n
        self.weight += self.lr * (self.prox_weight - self.proj_weight)

In [72]:
pe = PatchEmbedX(iter=100)

In [79]:
inp = torch.randn(1, 3, 224, 224)
pe(inp).shape
pe.inp = inp

In [80]:
pe.prox()

In [81]:
pe.proj()

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D