In [1]:
import torch
import torch.nn as nn

class MaskedLinear(nn.Linear):
    """``nn.Linear`` whose weight is multiplied element-wise by a mask buffer.

    The mask lives in the buffer ``weight_mask`` with the same shape as
    ``weight`` (``(out_features, in_features)``), initialised to all ones.
    Being a buffer, it is part of ``state_dict`` and follows ``.to()``, but it
    is not a trainable parameter. A 0 entry removes the corresponding weight
    from the forward pass; because the forward uses ``weight * weight_mask``,
    masked weights also receive zero gradient.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        device, dtype: optional factory kwargs forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        # Build the mask like the weight so it shares its device and dtype even
        # when device/dtype are passed at construction time.
        self.register_buffer('weight_mask', torch.ones_like(self.weight))

    def forward(self, input):
        masked_weight = self.weight * self.weight_mask
        return nn.functional.linear(input, masked_weight, self.bias)

# A 3 -> 4 (masked) -> ReLU -> 2 network whose first layer is a MaskedLinear.
model = nn.Sequential(MaskedLinear(3, 4), nn.ReLU(), nn.Linear(4, 2))

# Connectivity for the masked layer, shaped (out_features, in_features) = (4, 3):
# a 1 keeps the corresponding weight, a 0 zeroes it out in forward().
new_mask = torch.tensor(
    [[1, 0, 1],
     [1, 1, 0],
     [0, 1, 1],
     [1, 1, 1]]
)

# copy_ writes into the existing buffer in place (casting to its float dtype)
# and returns it, so the cell displays the updated mask.
model[0].weight_mask.copy_(new_mask)

tensor([[1., 0., 1.],
        [1., 1., 0.],
        [0., 1., 1.],
        [1., 1., 1.]])

In [7]:
# Look the mask up through the module's registered buffers (same tensor object).
dict(model[0].named_buffers())["weight_mask"]

tensor([[1., 0., 1.],
        [1., 1., 0.],
        [0., 1., 1.],
        [1., 1., 1.]])

In [16]:
import torch
import torch.nn as nn

class SimpleFeedforward(nn.Module):
    """Two-layer MLP: ``Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, output_size)``.

    Submodules are registered as ``hidden``, ``relu``, ``output`` in that
    order, which fixes the ``state_dict`` key order (``hidden.*`` first).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.output(self.relu(self.hidden(x)))

# Layer sizes for a 3 -> 4 -> 2 network.
input_size, hidden_size, output_size = 3, 4, 2

model = SimpleFeedforward(
    input_size=input_size, hidden_size=hidden_size, output_size=output_size
)
# state_dict maps "<submodule>.<param>" names to tensors; printed as plain text.
model_state_dict = model.state_dict()
print(model_state_dict)

OrderedDict([('hidden.weight', tensor([[-0.0608,  0.0539,  0.5331],
        [-0.4810,  0.1233,  0.0555],
        [-0.4860,  0.0546,  0.0107],
        [-0.2816,  0.4855, -0.0054]])), ('hidden.bias', tensor([ 0.1430,  0.2492, -0.4270,  0.0581])), ('output.weight', tensor([[ 0.3120, -0.1003,  0.1688,  0.1702],
        [-0.3985, -0.1890, -0.0265,  0.1303]])), ('output.bias', tensor([0.4874, 0.4113]))])


In [18]:
# Snapshot of every trainable tensor, in registration order.
[param for param in model.parameters()]

[Parameter containing:
 tensor([[-0.0608,  0.0539,  0.5331],
         [-0.4810,  0.1233,  0.0555],
         [-0.4860,  0.0546,  0.0107],
         [-0.2816,  0.4855, -0.0054]], requires_grad=True),
 Parameter containing:
 tensor([ 0.1430,  0.2492, -0.4270,  0.0581], requires_grad=True),
 Parameter containing:
 tensor([[ 0.3120, -0.1003,  0.1688,  0.1702],
         [-0.3985, -0.1890, -0.0265,  0.1303]], requires_grad=True),
 Parameter containing:
 tensor([0.4874, 0.4113], requires_grad=True)]

In [19]:
import torch.optim as optim
# Clear accumulated gradients directly on the module. Building an SGD optimizer
# (with a dummy lr) only to call its zero_grad() is unnecessary; both clear the
# same parameters' `.grad`. Parameter values are untouched, as the next cell shows.
model.zero_grad()

In [20]:
# Same listing after zeroing gradients: the values should be identical.
[param for param in model.parameters()]

[Parameter containing:
 tensor([[-0.0608,  0.0539,  0.5331],
         [-0.4810,  0.1233,  0.0555],
         [-0.4860,  0.0546,  0.0107],
         [-0.2816,  0.4855, -0.0054]], requires_grad=True),
 Parameter containing:
 tensor([ 0.1430,  0.2492, -0.4270,  0.0581], requires_grad=True),
 Parameter containing:
 tensor([[ 0.3120, -0.1003,  0.1688,  0.1702],
         [-0.3985, -0.1890, -0.0265,  0.1303]], requires_grad=True),
 Parameter containing:
 tensor([0.4874, 0.4113], requires_grad=True)]

In [21]:
import torch
from torch.utils.data import Dataset, DataLoader

class RandomTensorDataset(Dataset):
    """Synthetic classification dataset of standard-normal feature vectors.

    All data is drawn once at construction: features have shape
    ``(num_samples, input_size)`` and targets are int64 labels drawn uniformly
    from ``[0, num_classes)``.

    Args:
        num_samples: number of (feature, target) pairs.
        input_size: length of each feature vector.
        num_classes: number of distinct labels (default 2).
        generator: optional ``torch.Generator``; pass a seeded one for a
            reproducible dataset. ``None`` uses torch's global RNG.
    """

    def __init__(self, num_samples=100, input_size=10, num_classes=2, generator=None):
        self.num_samples = num_samples
        self.input_size = input_size
        self.num_classes = num_classes
        self.data = torch.randn(num_samples, input_size, generator=generator)
        self.targets = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

# 100 random samples of 10 features each, served in shuffled batches of 20.
random_dataset = RandomTensorDataset(num_samples=100, input_size=10)
dummy_loader = DataLoader(random_dataset, batch_size=20, shuffle=True)

# One pass over the loader, reporting the shape of every batch.
for batch_idx, batch in enumerate(dummy_loader):
    in_tensor, target = batch
    print(f"Batch {batch_idx}:")
    print(f"Input tensor (in_tensor) shape: {in_tensor.shape}")
    print(f"Target tensor shape: {target.shape}")
    print()


Batch 0:
Input tensor (in_tensor) shape: torch.Size([20, 10])
Target tensor shape: torch.Size([20])

Batch 1:
Input tensor (in_tensor) shape: torch.Size([20, 10])
Target tensor shape: torch.Size([20])

Batch 2:
Input tensor (in_tensor) shape: torch.Size([20, 10])
Target tensor shape: torch.Size([20])

Batch 3:
Input tensor (in_tensor) shape: torch.Size([20, 10])
Target tensor shape: torch.Size([20])

Batch 4:
Input tensor (in_tensor) shape: torch.Size([20, 10])
Target tensor shape: torch.Size([20])



In [80]:
import torch
import torch.nn as nn
import torch.optim as optim

def flatten_tensor_list(tensor_list):
    """Concatenate tensors into one 1-D tensor.

    Each tensor is flattened in row-major order and the pieces are joined in
    the order given.

    Args:
        tensor_list: iterable of tensors (e.g. parameters or gradients).

    Returns:
        1-D tensor whose length is the total number of elements.
        ``torch.cat`` raises if ``tensor_list`` is empty.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed weights,
    # are accepted; it still avoids a copy whenever a view is possible.
    return torch.cat([t.reshape(-1) for t in tensor_list])


In [87]:
class SimpleModel(nn.Module):
    """Tiny classifier: ``Linear(10, 5) -> ReLU -> Linear(5, 2)`` returning 2 logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Fresh model plus a single random example (batch of 1, 10 features) labelled class 1.
# Order matters: model init and randn both draw from the global RNG.
model = SimpleModel()
in_tensor = torch.randn((1, 10))
target = torch.ones(1, dtype=torch.long)


In [88]:
# Mean cross-entropy over the batch; expects raw logits and integer class targets.
criterion = torch.nn.CrossEntropyLoss(reduction="mean")


In [89]:
# Forward pass on the example, then score the logits against its label.
output = model(in_tensor)
loss = criterion(input=output, target=target)


In [84]:
def compute_sample_fisher(loss, return_outer_product=True, modules=None, weight_only=False):
    """Gradient of ``loss`` w.r.t. selected module parameters, plus derived quantities.

    Adapted from a method that read its modules from ``self``. Here the
    modules default to ``[model.fc1, model.fc2]`` of the notebook-global
    ``model`` (looked up at call time), so ``loss`` must have been computed
    from that same ``model``; otherwise ``torch.autograd.grad`` fails.

    Args:
        loss: tensor to differentiate. ``torch.autograd.grad`` is called
            without ``grad_outputs``, so this must be a scalar.
        return_outer_product: if True, also build ``outer(g, g)``.
        modules: optional iterable of modules whose parameters are
            differentiated; defaults to ``[model.fc1, model.fc2]``.
        weight_only: if True, skip parameters whose name contains ``'bias'``.

    Returns:
        A 4-tuple whose layout depends on ``return_outer_product``:
          * False -> ``(g, None, gTw, w)``
          * True  -> ``(outer(g, g), g, gTw, w)``
        ``g`` is the flattened gradient, ``w`` the flattened parameters in the
        same order, and ``gTw`` the scalar dot product ``g . w``.
    """
    if modules is None:
        modules = [model.fc1, model.fc2]

    param_list = [
        param
        for module in modules
        for name, param in module.named_parameters()
        if not (weight_only and 'bias' in name)
    ]

    grads = flatten_tensor_list(torch.autograd.grad(loss, param_list))
    params = flatten_tensor_list(param_list)

    # torch.dot instead of `params.T @ grads`: `.T` on a 1-D tensor is deprecated.
    gTw = torch.dot(params, grads)

    if not return_outer_product:
        return grads, None, gTw, params
    # torch.outer is the non-deprecated replacement for torch.ger.
    return torch.outer(grads, grads), grads, gTw, params

result = compute_sample_fisher(loss,return_outer_product=False)

In [86]:
# With return_outer_product=False the first tuple element is the flattened gradient.
grad_vector = result[0]
grad_vector

tensor([ 0.0898, -0.1266,  0.0653,  0.0692, -0.2971, -0.1322,  0.0081, -0.2103,
        -0.2309, -0.2419,  0.0643, -0.0907,  0.0468,  0.0496, -0.2127, -0.0947,
         0.0058, -0.1506, -0.1653, -0.1733,  0.0148, -0.0208,  0.0107,  0.0114,
        -0.0488, -0.0217,  0.0013, -0.0346, -0.0379, -0.0398, -0.0074,  0.0105,
        -0.0054, -0.0057,  0.0246,  0.0110, -0.0007,  0.0174,  0.0191,  0.0201,
        -0.0000,  0.0000, -0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000,
         0.0000,  0.0000, -0.2876, -0.2060, -0.0473,  0.0239,  0.0000,  0.3242,
         0.0914,  0.0735,  0.2146,  0.0000, -0.3242, -0.0914, -0.0735, -0.2146,
        -0.0000,  0.4008, -0.4008])

In [70]:
import numpy as np
# Draw 10 integers from range(10) with replacement (legacy global RNG, unseeded).
np.random.choice(10, size=10)

array([6, 9, 9, 2, 2, 8, 9, 5, 8, 2])

In [97]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

class SimpleModel(nn.Module):
    """Tiny network ``Linear(10, 5) -> Linear(5, 2)`` with no activation in between.

    NOTE(review): this redefinition shadows the earlier ``SimpleModel`` cell
    and, unlike it, has no ReLU, so the network is affine end to end.
    Confirm that dropping the nonlinearity is intended.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))

def flatten_tensor_list(tensors):
    """Concatenate tensors into one 1-D tensor.

    Each tensor is flattened in row-major order and the pieces are joined in
    the order given.

    Args:
        tensors: iterable of tensors (e.g. parameters or gradients).

    Returns:
        1-D tensor whose length is the total number of elements.
        ``torch.cat`` raises if ``tensors`` is empty.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed weights,
    # are accepted; it still avoids a copy whenever a view is possible.
    return torch.cat([t.reshape(-1) for t in tensors])

# Tiny 4-sample dataset: random 10-feature inputs with alternating labels 1, 0, 1, 0.
in_tensor = torch.randn(4, 10)
target = torch.tensor([1, 0] * 2)
dataset = TensorDataset(in_tensor, target)

def compute_sample_fisher(loss, return_outer_product=True, modules=None, weight_only=False):
    """Gradient of ``loss`` w.r.t. selected module parameters, plus derived quantities.

    The modules default to ``[model.fc1, model.fc2]`` of the notebook-global
    ``model`` (looked up at call time), so ``loss`` must have been computed
    from that same ``model``; otherwise ``torch.autograd.grad`` fails.

    Args:
        loss: tensor to differentiate. ``torch.autograd.grad`` is called
            without ``grad_outputs``, so this must be a scalar.
        return_outer_product: if True, also build ``outer(g, g)``.
        modules: optional iterable of modules whose parameters are
            differentiated; defaults to ``[model.fc1, model.fc2]``.
        weight_only: if True, skip parameters whose name contains ``'bias'``.

    Returns:
        A 4-tuple whose layout depends on ``return_outer_product``:
          * False -> ``(g, None, gTw, w)``
          * True  -> ``(outer(g, g), g, gTw, w)``
        ``g`` is the flattened gradient, ``w`` the flattened parameters in the
        same order, and ``gTw`` the scalar dot product ``g . w``.
    """
    if modules is None:
        modules = [model.fc1, model.fc2]

    param_list = [
        param
        for module in modules
        for name, param in module.named_parameters()
        if not (weight_only and 'bias' in name)
    ]

    grads = flatten_tensor_list(torch.autograd.grad(loss, param_list))
    params = flatten_tensor_list(param_list)

    # torch.dot instead of `params.T @ grads`: `.T` on a 1-D tensor is deprecated.
    gTw = torch.dot(params, grads)

    if not return_outer_product:
        return grads, None, gTw, params
    # torch.outer is the non-deprecated replacement for torch.ger.
    return torch.outer(grads, grads), grads, gTw, params

def _compute_wgH(model, dummy_loader, device, args, _fisher_mini_bsz):
    """Stack one flattened gradient vector per mini-batch into a matrix.

    For each ``(input, target)`` batch the loss of ``model`` is computed, and
    ``compute_sample_fisher`` returns its flattened gradient, which becomes one
    row of the result. Iteration stops once ``args.fisher_subsample_size *
    args.fisher_mini_bsz`` samples have been counted (``_fisher_mini_bsz`` per
    batch), or when the loader is exhausted.

    Args:
        model: network to evaluate; moved to ``device``.
        dummy_loader: iterable yielding ``(input, target)`` batches.
        device: torch device for the forward/backward pass.
        args: namespace providing ``fisher_subsample_size``,
            ``fisher_mini_bsz`` and ``disable_log_soft``.
        _fisher_mini_bsz: number of samples counted per processed batch.

    Returns:
        CPU tensor of shape ``(num_batches_processed, num_params)``.
    """
    model = model.to(device)

    goal = args.fisher_subsample_size
    target_samples = goal * args.fisher_mini_bsz

    # NOTE(review): `subset_indices` is read from notebook globals, not passed in.
    assert len(subset_indices) == target_samples

    if args.disable_log_soft:
        # cross_entropy applies log-softmax itself; presumably the model emits raw logits.
        criterion = torch.nn.functional.cross_entropy
    else:
        # nll_loss expects log-probabilities; presumably the model applies log-softmax.
        criterion = nn.functional.nll_loss

    Gs = []
    num_samples = 0
    for in_tensor, target in dummy_loader:
        in_tensor, target = in_tensor.to(device), target.to(device)
        loss = criterion(model(in_tensor), target)

        # NOTE(review): compute_sample_fisher differentiates w.r.t. modules of the
        # notebook-global `model`, not this function's `model` argument; this is
        # only correct when the caller passes that same object.
        g, _, _, _ = compute_sample_fisher(loss, return_outer_product=False)
        Gs.append(g[None, :].detach().cpu())

        num_samples += _fisher_mini_bsz
        # `>=` rather than `==`: when `_fisher_mini_bsz` differs from
        # `args.fisher_mini_bsz` the exact total may never be hit, which would
        # make the loop consume the whole loader.
        if num_samples >= target_samples:
            break

    return torch.cat(Gs, 0)

from types import SimpleNamespace

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel()

# Use every sample of the 4-example dataset.
subset_indices = list(range(len(dataset)))

# Plain attribute bag standing in for parsed CLI args (instead of abusing a
# lambda as a namespace). SimpleModel has no log_softmax, so disable_log_soft=True
# selects cross_entropy on its raw outputs.
args = SimpleNamespace(fisher_subsample_size=2, disable_log_soft=True)
args.fisher_mini_bsz = len(dataset) // args.fisher_subsample_size

dummy_loader = DataLoader(dataset, batch_size=args.fisher_mini_bsz, sampler=SubsetRandomSampler(subset_indices))

grads = _compute_wgH(model, dummy_loader, device, args, args.fisher_mini_bsz)
print("Gradients tensor:\n", grads)


Gradients tensor:
 tensor([[ 0.0021, -0.0111,  0.0155, -0.0009,  0.0021, -0.0066, -0.0117, -0.0051,
          0.0031, -0.0150, -0.0392,  0.2069, -0.2899,  0.0173, -0.0387,  0.1233,
          0.2182,  0.0956, -0.0586,  0.2817,  0.0225, -0.1186,  0.1662, -0.0099,
          0.0222, -0.0707, -0.1251, -0.0548,  0.0336, -0.1616,  0.0164, -0.0865,
          0.1212, -0.0072,  0.0162, -0.0516, -0.0912, -0.0400,  0.0245, -0.1178,
         -0.0059,  0.0310, -0.0435,  0.0026, -0.0058,  0.0185,  0.0327,  0.0143,
         -0.0088,  0.0422,  0.0014, -0.0255,  0.0146,  0.0107, -0.0038, -0.0780,
         -0.1699, -0.2387,  0.1845, -0.0218,  0.0780,  0.1699,  0.2387, -0.1845,
          0.0218,  0.0477, -0.0477],
        [-0.0062,  0.0086,  0.0154,  0.0089, -0.0101, -0.0107,  0.0055,  0.0103,
          0.0081, -0.0062,  0.1167, -0.1603, -0.2877, -0.1673,  0.1889,  0.2007,
         -0.1033, -0.1936, -0.1521,  0.1161, -0.0669,  0.0919,  0.1650,  0.0960,
         -0.1083, -0.1151,  0.0593,  0.1110,  0.0872,

In [101]:
# Sum over rows of g g^T (one row per mini-batch gradient): (num_params, num_params).
grads.t().mm(grads).shape

torch.Size([67, 67])