In [2]:
# Imports: PyTorch core, the `nn` namespace, its functional API, and the
# pruning utilities from torch.nn.utils.prune.
import torch
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torch import nn

print(torch.__version__)

1.11.0


In [3]:
# Seed PyTorch's RNG so weight initialisation (and every parameter value the
# notebook prints) is reproducible under Restart Kernel -> Run All.
# torch.manual_seed seeds the RNG on all devices, CPU and CUDA alike.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class LeNet(nn.Module):
    """LeNet-style CNN: two (3x3 conv -> ReLU -> 2x2 max-pool) stages
    followed by three fully connected layers with 10 outputs.

    ``conv1`` has ``in_channels=1``, so inputs must be single-channel
    (N, 1, H, W). ``fc1`` expects ``16 * 5 * 5`` features, which means the
    second stage must leave a 5x5 map. For example, 28x28 inputs give
    28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)   # 1 -> 6 channels, 3x3 kernel
        self.conv2 = nn.Conv2d(6, 16, 3)  # 6 -> 16 channels, 3x3 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of shape (N, 1, H, W) to raw scores of shape (N, 10).

        No softmax is applied; the output is the ``fc3`` result.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # The second stage must use conv2 (6 -> 16 channels). Reusing conv1
        # (1 input channel) fails on the 6-channel activation above.
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dim. Unlike the manual
        # nelement()/shape[0] computation, this also handles an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
    
# Build the network and move its parameters to the selected device.
model = LeNet().to(device)

In [5]:
# Focus on the first convolution layer and list its (name, Parameter) pairs.
# As the output shows, these are just `weight` and `bias`.
module = model.conv1
print([(name, param) for name, param in module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.0666, -0.2620,  0.1074],
          [-0.1598,  0.1049,  0.1527],
          [-0.2361,  0.2298,  0.3127]]],


        [[[ 0.0284, -0.2753,  0.1269],
          [ 0.1618,  0.2772, -0.0157],
          [-0.1167,  0.1997,  0.2531]]],


        [[[ 0.2158,  0.0095,  0.1519],
          [ 0.2447,  0.1981,  0.1379],
          [ 0.1105,  0.1449, -0.2628]]],


        [[[-0.1802, -0.2523, -0.3006],
          [ 0.1046, -0.3260, -0.1194],
          [-0.0501, -0.1159,  0.1437]]],


        [[[ 0.0746,  0.0294, -0.1666],
          [ 0.0631, -0.0391, -0.2990],
          [-0.1873,  0.0507,  0.2351]]],


        [[[-0.2670,  0.0611, -0.2217],
          [ 0.0720,  0.3183, -0.0402],
          [-0.0613,  0.2406, -0.0184]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0111, -0.0594,  0.1493, -0.2532, -0.1392,  0.0758],
       requires_grad=True))]


In [6]:
# Show conv1's registered buffers (the output shows none) and its current weight.
print("named_buffers: ", [*module.named_buffers()])
print("weight: ", getattr(module, "weight"))

named_buffers:  []
weight:  Parameter containing:
tensor([[[[ 0.0666, -0.2620,  0.1074],
          [-0.1598,  0.1049,  0.1527],
          [-0.2361,  0.2298,  0.3127]]],


        [[[ 0.0284, -0.2753,  0.1269],
          [ 0.1618,  0.2772, -0.0157],
          [-0.1167,  0.1997,  0.2531]]],


        [[[ 0.2158,  0.0095,  0.1519],
          [ 0.2447,  0.1981,  0.1379],
          [ 0.1105,  0.1449, -0.2628]]],


        [[[-0.1802, -0.2523, -0.3006],
          [ 0.1046, -0.3260, -0.1194],
          [-0.0501, -0.1159,  0.1437]]],


        [[[ 0.0746,  0.0294, -0.1666],
          [ 0.0631, -0.0391, -0.2990],
          [-0.1873,  0.0507,  0.2351]]],


        [[[-0.2670,  0.0611, -0.2217],
          [ 0.0720,  0.3183, -0.0402],
          [-0.0613,  0.2406, -0.0184]]]], requires_grad=True)
