In [1]:
import torch
import torch.nn as nn

# Toy problem sizes: a 6x6 grid with 10 input channels, 20 output channels
# and a 3x3 kernel.
H, W = 6, 6
C_in = 10
C_out = 20
kernel_size = 3

# Initial flattened vector holding C_in*H*W values for a single sample.
x = torch.randn(1, C_in*H*W)

# Reshape the flat vector into the (batch, channels, height, width) layout
# that nn.Conv2d consumes; -1 lets view() infer the batch size (1 here).
x = x.view(-1, C_in, H, W)
print(x.size())

# Conv2d with stride 1 and no padding: each spatial side shrinks by
# kernel_size - 1, i.e. 6 -> 4 (see the second printed size).
conv = nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=0, bias=True)
y = conv(x)

# Non-linearity (element-wise, so the shape is unchanged).
relu = nn.ReLU()
y = relu(y)
print(y.size())

# Number of parameters in Conv2d:
# C_out*C_in*k*k weights + C_out biases = 20*10*3*3 + 20 = 1820.
num_params = sum(p.numel() for p in conv.parameters())
print(num_params)

torch.Size([1, 10, 6, 6])
torch.Size([1, 20, 4, 4])
1820


In [128]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomConv2d(nn.Module):
    """Locally connected 2D layer whose output width varies per pixel.

    Each output location (i, j) owns a private two-layer MLP
    (Linear -> ReLU -> Linear) that is applied to the input patch which
    `F.unfold` extracts for that location. No parameters are shared between
    locations, and the number of output features d_{ij} is read from
    `out_channels_map[i, j]`.
    """

    def __init__(self, in_channels, out_channels_map, kernel_size, stride=1, padding=0, nn_hidden_dim=4):
        """
        Parameters:
          in_channels: number of input channels.
          out_channels_map: 2D torch.Tensor of shape (H_out, W_out) containing, at each (i,j),
                            the desired number of output channels d_{ij}.
          kernel_size: int or tuple, the size of the kernel.
          stride, padding: convolution parameters, forwarded unchanged to F.unfold.
          nn_hidden_dim: width of the hidden layer of every per-pixel MLP.
        """
        super(CustomConv2d, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding

        # Save the desired output dimensions per spatial location.
        # out_channels_map must be 2D; its shape defines the output grid.
        self.out_channels_map = out_channels_map
        H_out, W_out = out_channels_map.shape
        self.H_out = H_out
        self.W_out = W_out

        # Length of one flattened patch as produced by F.unfold.
        in_dim = in_channels * self.kernel_size[0] * self.kernel_size[1]

        # One entry per output pixel, keyed by "i_j":
        #   weights[key]     : (nn_hidden_dim, in_dim)   hidden-layer weight
        #   biases[key]      : (nn_hidden_dim,)          hidden-layer bias
        #   out_weights[key] : (d_{ij}, nn_hidden_dim)   output-layer weight
        #   out_biases[key]  : (d_{ij},)                 output-layer bias
        self.weights = nn.ParameterDict()
        self.biases = nn.ParameterDict()
        self.out_weights = nn.ParameterDict()
        self.out_biases = nn.ParameterDict()
        for i in range(H_out):
            for j in range(W_out):
                out_channels = int(out_channels_map[i, j])
                key = f"{i}_{j}"
                # Standard-normal initialization (a different scheme can be
                # substituted here). The creation order is kept fixed so that
                # a given RNG seed always yields the same parameters.
                self.weights[key] = nn.Parameter(torch.randn(nn_hidden_dim, in_dim))
                self.biases[key] = nn.Parameter(torch.randn(nn_hidden_dim))
                self.out_weights[key] = nn.Parameter(torch.randn(out_channels, nn_hidden_dim))
                self.out_biases[key] = nn.Parameter(torch.randn(out_channels))

    @staticmethod
    def _pair(value):
        """Expand an int (or 1-element sequence) into a (height, width) tuple."""
        if isinstance(value, (tuple, list)):
            return (value[0], value[0]) if len(value) == 1 else tuple(value)
        return (value, value)

    def _output_size(self, H, W):
        """Spatial size of the sliding-window grid for an (H, W) input (dilation 1)."""
        k_h, k_w = self.kernel_size
        s_h, s_w = self._pair(self.stride)
        p_h, p_w = self._pair(self.padding)
        return ((H + 2 * p_h - k_h) // s_h + 1, (W + 2 * p_w - k_w) // s_w + 1)

    def forward(self, x):
        """
        x: Input tensor of shape (B, in_channels, H, W).

        Returns:
          A nested list outputs[i][j] for each spatial position (i,j), where outputs[i][j] has shape (B, d_{ij}).

        Raises:
          AssertionError: if the (H_out, W_out) grid implied by the input size,
            kernel, stride and padding differs from the shape of out_channels_map.
        """
        _, _, H, W = x.shape
        # Extract patches: result shape is (B, in_channels * k_h * k_w, L), where L = H_out * W_out.
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)

        # Reshape patches to (B, L, in_dim)
        patches = patches.transpose(1, 2)

        # Compare the full (H_out, W_out) grid, not just the product L: a map
        # with the right number of pixels but a different layout would
        # otherwise pass and pair each MLP with the wrong patch.
        assert self._output_size(H, W) == (self.H_out, self.W_out), "Output spatial dimensions do not match out_channels_map shape."

        outputs = []
        for i in range(self.H_out):
            row_outputs = []
            for j in range(self.W_out):
                key = f"{i}_{j}"
                # F.unfold enumerates locations row by row, so pixel (i, j)
                # sits at flat index i * W_out + j. patch has shape (B, in_dim).
                patch = patches[:, i * self.W_out + j, :]
                # Hidden layer: (B, in_dim) -> (B, nn_hidden_dim)
                hidden = F.relu(F.linear(patch, self.weights[key], self.biases[key]))
                # Output layer: (B, nn_hidden_dim) -> (B, d_{ij})
                row_outputs.append(F.linear(hidden, self.out_weights[key], self.out_biases[key]))
            outputs.append(row_outputs)
        # outputs is a list of lists such that outputs[i][j] is a tensor of shape (B, d_{ij})
        return outputs


In [134]:
# Define input tensor of shape (B, in_channels, H, W)
B, in_channels, H, W = 1, 4, 8, 8
x = torch.randn(B, in_channels, H, W)

# With a kernel of size 5, stride 1 and padding 2 (the values passed below),
# the output spatial dimensions remain 8x8: (8 + 2*2 - 5) // 1 + 1 = 8.
# Define a custom output channel map where the number of channels varies with position.
out_channels_map = torch.randint(4, 8, (8, 8))  # each pixel gets between 4 and 7 channels (upper bound is exclusive)

# Create the custom convolution layer.
custom_conv = CustomConv2d(in_channels, out_channels_map, kernel_size=5, stride=1, padding=2)

# Forward pass; returns a nested list indexed as outputs[i][j].
outputs = custom_conv(x)

# Check the output shapes: the last dimension of outputs[i][j] should equal
# out_channels_map[i, j].
print(outputs[0][0].shape)
print(outputs[0][1].shape)
print(out_channels_map)

torch.Size([1, 6])
torch.Size([1, 6])
tensor([[6, 6, 4, 6, 5, 7, 5, 7],
        [4, 4, 4, 7, 4, 6, 4, 5],
        [6, 4, 5, 6, 5, 7, 4, 4],
        [6, 4, 7, 4, 6, 7, 4, 6],
        [4, 5, 6, 5, 5, 5, 6, 6],
        [7, 6, 4, 5, 6, 5, 7, 5],
        [4, 4, 5, 5, 6, 4, 5, 7],
        [5, 4, 7, 7, 6, 7, 6, 7]])


In [176]:
from vmc_torch.experiment.tn_model import *
class fTN_backflow_attn_Tensorwise_Model_v1(wavefunctionModel):
    """
    For each on-site fermionic tensor with specific shape, assign a narrow on-site projector MLP with corresponding output dimension.
    This is to avoid the large number of parameters in the previous model, where Np = N_neurons * N_TNS.

    The configuration is first passed through a shared `SelfAttn_block`; the
    flattened attention features are then fed to one small MLP per tensor
    (Linear -> ReLU -> Linear), whose output has as many entries as that
    tensor has parameters. The concatenated MLP outputs, scaled by `nn_eta`,
    are added to the tensor-network parameters before the amplitude is
    contracted.
    """

    def __init__(self, ftn, max_bond=None, embedding_dim=32, attention_heads=4, nn_final_dim=4, nn_eta=1.0, dtype=torch.float32):
        """
        Parameters:
          ftn: tensor network to wrap; must expose Lx, Ly, phys_dim(),
               max_bond() and arrays (NOTE(review): presumably a fermionic
               PEPS -- confirm against the caller).
          max_bond: bond dimension for boundary contraction; None or a value
                    <= 0 selects exact contraction.
          embedding_dim, attention_heads: forwarded to SelfAttn_block.
          nn_final_dim: hidden width of every per-tensor MLP.
          nn_eta: scale applied to the NN correction before it is added to
                  the tensor-network parameters.
          dtype: dtype of the attention block and the MLPs, and the dtype
                 configurations are cast to.
        """
        super().__init__()
        self.param_dtype = dtype

        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        # (keys must be strings, so tid and sector are stringified here and
        # parsed back in amplitude()).
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Define the neural network
        input_dim = ftn.Lx * ftn.Ly
        phys_dim = ftn.phys_dim()

        self.nn = SelfAttn_block(
            n_site=input_dim,
            num_classes=phys_dim,
            embedding_dim=embedding_dim,
            attention_heads=attention_heads,
            dtype=self.param_dtype
        )

        # for each tensor (labelled by tid), assign a MLP
        # Every MLP reads the same flattened attention features, so the input
        # width is identical for all of them and is computed once.
        mlp_input_dim = input_dim * embedding_dim
        self.mlp = nn.ModuleDict()
        for tid in self.torch_tn_params.keys():
            # Number of entries this tensor contributes to the flat parameter vector.
            tn_params_vec = flatten_tn_params({tid: params[int(tid)]})
            self.mlp[tid] = nn.Sequential(
                nn.Linear(mlp_input_dim, nn_final_dim),
                nn.ReLU(),
                nn.Linear(nn_final_dim, tn_params_vec.numel()),
            )
            self.mlp[tid].to(self.param_dtype)

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        # Note: 'max_bond' records the value as passed in, i.e. before the
        # "<= 0 means None" normalization below.
        self.model_structure = {
            'fPEPS_backflow_attn_Tensorwise':
            {
                'D': ftn.max_bond(),
                'Lx': ftn.Lx, 'Ly': ftn.Ly,
                'symmetry': self.symmetry,
                'nn_final_dim': nn_final_dim,
                'nn_eta': nn_eta,
                'embedding_dim': embedding_dim,
                'attention_heads': attention_heads,
                'max_bond': max_bond,
            },
        }
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond
        self.nn_eta = nn_eta
        # Contraction tree for the exact path; built lazily on first use and reused.
        self.tree = None

    def amplitude(self, x):
        """
        Compute the amplitude of every configuration in `x`.

        x: iterable of configurations, expected to be batched as
           (batch_size, input_dim); entries that are not tensors are converted.

        Returns:
          The per-configuration amplitudes stacked with torch.stack.
        """
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        params_vec = flatten_tn_params(params)

        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            # Bring x_i to a tensor of the model dtype. isinstance (rather
            # than an exact type comparison) also accepts Tensor subclasses.
            if not isinstance(x_i, torch.Tensor):
                x_i = torch.tensor(x_i, dtype=self.param_dtype)
            elif x_i.dtype != self.param_dtype:
                x_i = x_i.to(self.param_dtype)

            # Get the NN correction to the parameters, concatenate the results for each tensor
            nn_features = self.nn(x_i)
            nn_features_vec = nn_features.view(-1)
            nn_correction = torch.cat([self.mlp[tid](nn_features_vec) for tid in self.torch_tn_params.keys()])
            # Add the correction to the original parameters
            tn_nn_params = reconstruct_proj_params(params_vec + self.nn_eta*nn_correction, params)
            # Reconstruct the TN with the new parameters
            psi = qtn.unpack(tn_nn_params, self.skeleton)
            # Get the amplitude
            amp = psi.get_amp(x_i, conj=True)

            if self.max_bond is None:
                # Exact contraction: find a contraction tree once, then reuse it.
                if self.tree is None:
                    opt = ctg.ReusableHyperOptimizer()
                    self.tree = amp.contraction_tree(optimize=opt)
                amp_val = amp.contract(optimize=self.tree)
            else:
                # Approximate contraction: compress from both y-boundaries
                # towards the middle, then contract what is left.
                amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly//2-1])
                amp = amp.contract_boundary_from_ymax(max_bond=self.max_bond, cutoff=0.0, yrange=[psi.Ly//2, psi.Ly-1])
                amp_val = amp.contract()

            if amp_val==0.0:
                # Replace a zero result (presumably a plain Python scalar in
                # some cases -- confirm) with a tensor so torch.stack works;
                # use the model dtype rather than the float32 default so it
                # matches the other amplitudes.
                amp_val = torch.tensor(0.0, dtype=self.param_dtype)
            batch_amps.append(amp_val)

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)
    
class fTN_backflow_attn_conv_Model(wavefunctionModel):
    """
    Backflow model in which the NN correction to each on-site tensor is
    produced by a position-specific local MLP acting on attention features.

    The configuration is passed through a `SelfAttn_block`; the features are
    arranged as an (embedding_dim, Lx, Ly) image and fed to a `CustomConv2d`
    layer whose per-pixel output width equals the number of parameters of the
    corresponding tensor. The concatenated outputs, scaled by `nn_eta`, are
    added to the tensor-network parameters before the amplitude is contracted.
    """

    def __init__(self, ftn, max_bond=None, embedding_dim=32, attention_heads=4, nn_final_dim=128, nn_eta=1e-3, dtype=torch.float32):
        """
        Parameters:
          ftn: tensor network to wrap; must expose Lx, Ly, phys_dim(),
               max_bond() and arrays (NOTE(review): presumably a fermionic
               PEPS -- confirm against the caller).
          max_bond: bond dimension for boundary contraction; None or a value
                    <= 0 selects exact contraction.
          embedding_dim, attention_heads: forwarded to SelfAttn_block;
                    embedding_dim is also the input channel count of the
                    CustomConv2d layer.
          nn_final_dim: hidden width of every per-pixel MLP inside CustomConv2d.
          nn_eta: scale applied to the NN correction before it is added to
                  the tensor-network parameters.
          dtype: dtype of the attention block and the conv layer, and the
                 dtype configurations are cast to.
        """
        super().__init__()
        self.param_dtype = dtype

        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        # (keys must be strings, so tid and sector are stringified here and
        # parsed back in amplitude()).
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Define the neural network
        input_dim = ftn.Lx * ftn.Ly
        self.Lx = ftn.Lx
        self.Ly = ftn.Ly
        phys_dim = ftn.phys_dim()

        self.attn_block = SelfAttn_block(
            n_site=input_dim,
            num_classes=phys_dim,
            embedding_dim=embedding_dim,
            attention_heads=attention_heads,
            dtype=self.param_dtype
        )

        # Number of flat parameters of each tensor, in torch_tn_params key order.
        self.ts_dim_list = [
            flatten_tn_params({tid: params[int(tid)]}).numel()
            for tid in self.torch_tn_params.keys()
        ]
        # NOTE(review): the reshape assumes the tensor ids enumerate the
        # lattice sites in row-major (Lx, Ly) order -- confirm.
        self.out_channels_map = torch.tensor(self.ts_dim_list).view(ftn.Lx, ftn.Ly)

        # define the convolutional layer
        # kernel_size=5 with padding=2 and stride=1 keeps the Lx x Ly grid size.
        self.conv = CustomConv2d(embedding_dim, self.out_channels_map, kernel_size=5, stride=1, padding=2, nn_hidden_dim=nn_final_dim)
        self.conv.to(self.param_dtype)

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        # NOTE(review): the key 'fPEPS_backflow_attn' does not mention the
        # conv variant and may clash with another model's identifier --
        # confirm before relying on it to tell models apart.
        # 'max_bond' records the value as passed in, i.e. before the
        # "<= 0 means None" normalization below.
        self.model_structure = {
            'fPEPS_backflow_attn':{'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry, 'nn_final_dim': nn_final_dim, 'nn_eta': nn_eta, 'max_bond': max_bond, 'embedding_dim': embedding_dim, 'attention_heads': attention_heads},
        }
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond
        self.nn_eta = nn_eta
        # Contraction tree for the exact path; built lazily on first use and reused.
        self.tree = None

    def amplitude(self, x):
        """
        Compute the amplitude of every configuration in `x`.

        x: iterable of configurations, expected to be batched as
           (batch_size, input_dim); entries that are not tensors are converted.

        Returns:
          The per-configuration amplitudes stacked with torch.stack.
        """
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        params_vec = flatten_tn_params(params)

        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            # Bring x_i to a tensor of the model dtype. isinstance (rather
            # than an exact type comparison) also accepts Tensor subclasses.
            if not isinstance(x_i, torch.Tensor):
                x_i = torch.tensor(x_i, dtype=self.param_dtype)
            elif x_i.dtype != self.param_dtype:
                x_i = x_i.to(self.param_dtype)

            # Get the NN correction to the parameters
            nn_features = self.attn_block(x_i)
            # Arrange the features as a single-sample image:
            # (1, Lx, Ly, E) -> (1, E, Lx, Ly), the layout CustomConv2d expects.
            nn_features = nn_features.view(1, self.Lx, self.Ly, -1).permute(0, 3, 1, 2)
            nn_correction = self.conv(nn_features)
            # conv returns a nested list [i][j] of (1, d_ij) tensors; take the
            # single batch entry of each and concatenate in row-major order.
            nn_correction = torch.cat([nn_correction[i][j][0] for i in range(self.Lx) for j in range(self.Ly)], dim=0)

            # Add the correction to the original parameters
            tn_nn_params = reconstruct_proj_params(params_vec + self.nn_eta*nn_correction, params)
            # Reconstruct the TN with the new parameters
            psi = qtn.unpack(tn_nn_params, self.skeleton)
            # Get the amplitude
            amp = psi.get_amp(x_i, conj=True)

            if self.max_bond is None:
                # Exact contraction: find a contraction tree once, then reuse it.
                if self.tree is None:
                    opt = ctg.ReusableHyperOptimizer()
                    self.tree = amp.contraction_tree(optimize=opt)
                amp_val = amp.contract(optimize=self.tree)
            else:
                # Approximate contraction: compress from both y-boundaries
                # towards the middle, then contract what is left.
                amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly//2-1])
                amp = amp.contract_boundary_from_ymax(max_bond=self.max_bond, cutoff=0.0, yrange=[psi.Ly//2, psi.Ly-1])
                amp_val = amp.contract()

            if amp_val==0.0:
                # Replace a zero result (presumably a plain Python scalar in
                # some cases -- confirm) with a tensor so torch.stack works;
                # use the model dtype rather than the float32 default so it
                # matches the other amplitudes.
                amp_val = torch.tensor(0.0, dtype=self.param_dtype)

            batch_amps.append(amp_val)

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)

In [178]:
from vmc_torch.fermion_utils import generate_random_fpeps
from vmc_torch.hamiltonian import spinful_Fermi_Hubbard_square_lattice
import jax

# Lattice and tensor-network settings shared by both models below.
Lx = Ly = 6
D = 4
symmetry = 'Z2'
# N_f equals the number of lattice sites (presumably half filling for the
# spinful model -- confirm).
N_f = int(Lx*Ly)
# Passed as max_bond below; both model constructors map max_bond <= 0 to None,
# which selects the exact-contraction branch of amplitude().
chi = -1
dtype=torch.float64
# [0] keeps only the first element of what generate_random_fpeps returns.
peps = generate_random_fpeps(Lx, Ly, D=D, seed=2, symmetry=symmetry, Nf=N_f, spinless=False)[0]
# Convert every array of the network to a torch tensor of the working dtype.
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype))
# NOTE(review): the positional 1, 8 are presumably the hopping t and the
# on-site interaction U -- confirm against the function signature.
Ham = spinful_Fermi_Hubbard_square_lattice(Lx, Ly, 1, 8, N_f, pbc=False, n_fermions_per_spin=(N_f//2, N_f//2))
# Draw one configuration with a fixed PRNG key and convert it to a torch
# tensor of the working dtype.
random_config = Ham.hilbert.random_state(key=jax.random.PRNGKey(2))
random_config = torch.tensor(random_config, dtype=dtype)
# Build both backflow variants on the same network with matching
# hyper-parameters; note that nn_eta differs (1e-3 vs 1.0).
model = fTN_backflow_attn_conv_Model(peps, max_bond=chi, embedding_dim=16, attention_heads=4, nn_final_dim=4, nn_eta=1e-3, dtype=dtype)
model1 = fTN_backflow_attn_Tensorwise_Model_v1(peps, max_bond=chi, embedding_dim=16, attention_heads=4, nn_final_dim=4, nn_eta=1.0, dtype=dtype)

In [179]:
# Call each model once on the sampled configuration. This tuple is not the
# cell's last expression, so its value is evaluated but not displayed.
model(random_config), model1(random_config)
# Displayed output: parameter counts as (conv model, tensor-wise model).
# NOTE(review): num_params is not defined in this notebook; presumably it is
# provided by the wavefunctionModel base class -- confirm.
model.num_params, model1.num_params

(121120, 146464)