In [1]:
import torch
from torch.utils.cpp_extension import load
import os

os.environ['CUDA_LAUNCH_BLOCKING']="1"

torch integer data flow 는 gpu에서 사용 불가

In [2]:
# int_data = torch.randint(0,255,(1,3,24,24), dtype=torch.uint8)
# weight = torch.randint(0,255,(1,3,3,3), dtype=torch.uint8)

# # b = torch.nn.functional.conv2d(int_data.cuda(), weight=weight.cuda(),stride=1)
# b = torch.nn.functional.conv2d(int_data, weight=weight,stride=1,bias=None, padding=1, dtype=torch.uint8)
# print(b.shape)

In [3]:
import int8mm_cuda
from torch.nn.modules import Module

class IntLinear(Module):
    def __init__(self, in_channels, out_channels):
        super(IntLinear,self).__init__()
        self.weight = torch.randint(-127,127,(out_channels, in_channels), dtype=torch.int8)

    def forward(self,x):
        # weight [OUT, IN} - > [IN, OUT]
        # input [BATCH, IN]
        y = int8mm_cuda.int8_mm(x,self.weight.transpose(1,0).contiguous())
        y = (y > 127).int()*5
        y = y.type(torch.int8)
        return y
    
    def cuda(self):
        self.weight = self.weight.cuda()
        

mm = IntLinear(4,12)
x = torch.randint(-127,127,(1,4), dtype=torch.int8).cuda()
print(x.dtype, mm.weight.dtype)
with torch.no_grad():
    mm.cuda()
    y = mm(x)
print(x)
print(mm.weight)
print(y)

torch.int8 torch.int8
tensor([[ -40, -114, -111,   91]], device='cuda:0', dtype=torch.int8)
tensor([[-109,  -23, -116,   -8],
        [  82,  -50, -109,  -23],
        [ -49,  107,   29,  -39],
        [ -71,   89,   93, -109],
        [  39,  -66, -118,   85],
        [-110, -104,   39,   20],
        [   7,  123,  -25, -116],
        [  57,   30,   14,   40],
        [  50,   28,   66,    0],
        [  52,   54, -110,  -85],
        [  -4,  -41,   40,    4],
        [-118,  -17,  -81, -127]], device='cuda:0', dtype=torch.int8)
tensor([[5, 5, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5]], device='cuda:0',
       dtype=torch.int8)


In [4]:
import numpy as np 
print(mm.weight.transpose(1,0),end="\n\n")
x_data = x.detach().cpu().numpy()
mm_data = mm.weight.detach().cpu().numpy()
print(f"x data - {x_data}\n")
print(f"mm data - {mm_data}\nmm Trans - {mm_data.T}\n")
y = x_data @ mm_data.T
print(f"y - {y}")

tensor([[-109,   82,  -49,  -71,   39, -110,    7,   57,   50,   52,   -4, -118],
        [ -23,  -50,  107,   89,  -66, -104,  123,   30,   28,   54,  -41,  -17],
        [-116, -109,   29,   93, -118,   39,  -25,   14,   66, -110,   40,  -81],
        [  -8,  -23,  -39, -109,   85,   20, -116,   40,    0,  -85,    4, -127]],
       device='cuda:0', dtype=torch.int8)

x data - [[ -40 -114 -111   91]]

mm data - [[-109  -23 -116   -8]
 [  82  -50 -109  -23]
 [ -49  107   29  -39]
 [ -71   89   93 -109]
 [  39  -66 -118   85]
 [-110 -104   39   20]
 [   7  123  -25 -116]
 [  57   30   14   40]
 [  50   28   66    0]
 [  52   54 -110  -85]
 [  -4  -41   40    4]
 [-118  -17  -81 -127]]
mm Trans - [[-109   82  -49  -71   39 -110    7   57   50   52   -4 -118]
 [ -23  -50  107   89  -66 -104  123   30   28   54  -41  -17]
 [-116 -109   29   93 -118   39  -25   14   66 -110   40  -81]
 [  -8  -23  -39 -109   85   20 -116   40    0  -85    4 -127]]

y - [[ -70 -118 -110  100  -83  -77  -67  

Pooling layer


In [5]:
import int8pool_cuda

class IntPool(Module):
    def __init__(self,kernel_size = 2, stride = 2, padding=0, mode=0):
        super(IntPool,self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.mode = mode
    
    def forward(self,x):
        y = int8pool_cuda.int8_pool(x,self.kernel_size, self.stride, self.padding, self.mode)
        # y = (y > 10).int()*5
        # y = y.type(torch.int8)
        return y

pool = IntPool()
x = torch.randint(0, 127,(4,32,32,4), dtype=torch.int8).cuda()

In [6]:
with torch.no_grad():
    y = pool(x)

In [7]:
print(f"x - {x.shape} \n{x}\n")
print(f"y - {y.shape}\n{y}")

x - torch.Size([4, 32, 32, 4]) 
tensor([[[[  9,  91,  37,  72],
          [113,  69,  47, 112],
          [ 47,  47,  44, 119],
          ...,
          [107,  86,  64,  96],
          [126,  18,  11, 101],
          [ 30,  67,  48,  93]],

         [[124,  86,   2,  66],
          [ 84,  83, 106,  69],
          [ 13,  32,  30, 109],
          ...,
          [ 84,  96,  70, 115],
          [ 83,  17,  15,  21],
          [121,  10, 109,  77]],

         [[ 37,  95,  27,  23],
          [114,  76,  85,  75],
          [ 86,  63,  26,  99],
          ...,
          [ 11, 107,  39,  36],
          [ 96,  61, 108,  43],
          [109,   8,   4,  82]],

         ...,

         [[ 40,  20,   9,  10],
          [ 72,  62,   8,  77],
          [ 54,  97,  41,  57],
          ...,
          [ 94,   3,  26,  50],
          [  4,  71,  82,  27],
          [ 99,  84,  17,  52]],

         [[ 65,  25,  88,  88],
          [ 63,   8,  30,  50],
          [115,  16,  61,  76],
          ...,
      

In [8]:
avg_pool = IntPool(mode=1)
x = torch.randint(0, 127,(4,32,32,4), dtype=torch.int8).cuda()
with torch.no_grad():
    y = pool(x)
print(f"x - {x.shape} \n{x}\n")
print(f"y - {y.shape}\n{y}")

x - torch.Size([4, 32, 32, 4]) 
tensor([[[[ 39,  64,  95,  30],
          [ 15,  92,  13,  87],
          [ 13,  23,  62,  39],
          ...,
          [ 17,  56, 112,  92],
          [ 75,  32,  38,  86],
          [ 30, 106,  75,   1]],

         [[ 15,  87,  74,  62],
          [107,  61,  65,  86],
          [ 46,  75,  81,  44],
          ...,
          [ 70,  54,  67,  67],
          [ 15,   5,  54, 109],
          [ 51,  65,  61,  68]],

         [[ 28,  20,   8,  35],
          [ 32,  49, 106, 112],
          [ 97,  83,  21,  57],
          ...,
          [110,  73,  41,  64],
          [ 86,  34,  44, 115],
          [ 95,  97,  31,  24]],

         ...,

         [[ 96,  76,  99,  19],
          [114,  67,  32,  37],
          [ 72,  57,  21, 126],
          ...,
          [ 57,  17, 101, 123],
          [  2, 114, 102,  86],
          [ 36,  17, 102,  11]],

         [[116,  47, 119, 110],
          [ 32,  15,  69,  26],
          [122,  55,  14, 123],
          ...,
      

pytorch conv layer parmeter shape 보기

In [9]:
import torchvision
model = torchvision.models.vgg.vgg16(pretrained=True)
print(model.features[0].weight.shape)
c = torch.nn.Conv2d(4,12,3,1,1)
print(c.weight.shape) # NCHW



torch.Size([64, 3, 3, 3])
torch.Size([12, 4, 3, 3])


In [10]:
import cutlassconv
from torch.nn.modules import Module

class IntConv2d(Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride =1, padding =1):
        super(IntConv2d,self).__init__()
        self.weight = torch.randint(-127,127,(out_channels, kernel_size, kernel_size, in_channels), dtype=torch.int8)
        self.stride = stride
        self.padding = padding

    def forward(self,x):
        # trans_weight = torch.flip(self.weight,[1,2]).transpose(0,3).contiguous()
        # trans_weight = self.weight.permute(0,2,3,1).contiguous()
        trans_weight = self.weight
        return cutlassconv.int8_conv(x,trans_weight)
    
    def cuda(self):
        self.weight = self.weight.cuda()
        
## cutlass는 16의 배수만
input_channel= 16
conv = IntConv2d(input_channel,32,3,1,1)
print(conv.weight.shape)
x = torch.randint(0,127,(1,32,32,input_channel), dtype=torch.int8).cuda()


torch.Size([32, 3, 3, 16])


In [11]:
with torch.no_grad():
    conv.cuda()
    y = conv(x)


In [12]:
import numpy as np 
print(f"x data - {x.shape} \n{x_data}\n")
print(f"conv data - {conv.weight.shape}\n{conv.weight}\n")
print(f"y data - {y.shape}\n{y}")

x data - torch.Size([1, 32, 32, 16]) 
[[ -40 -114 -111   91]]

conv data - torch.Size([32, 3, 3, 16])
tensor([[[[  74,  -19,  -82,  ...,   48,   38,  -46],
          [  33, -114,   -1,  ...,  -23,   93,  -68],
          [ 101,  -54,  112,  ...,   20,  -87,  -80]],

         [[ -64,   21,    5,  ..., -120, -114,   19],
          [  55,   90,  -38,  ...,   18,   52,   -2],
          [  93,    4,  106,  ...,   10,  -38,   24]],

         [[ 111,  -34,   96,  ...,  -94,   67,   78],
          [-102,   56,  -19,  ...,  -40,   -7,   -1],
          [-109,   63,   51,  ...,  -46,   91,   54]]],


        [[[ -29,  -32,   -8,  ...,   51,  -99,  -57],
          [  -3,  -48,  -47,  ...,  116,  115, -111],
          [  90,  119,  124,  ..., -127,   28,   63]],

         [[ -92,  -98,   82,  ...,   47,  -49,   31],
          [ -26,  -92,   -3,  ...,  -90, -112,  -65],
          [-126,   -3,  -63,  ...,  -33,   -4,   91]],

         [[  31,   94,  121,  ...,  -37,   78,  -46],
          [ -41,   33,

In [13]:
import int8conv_cuda
from torch.nn.modules import Module

class IntConv2d(Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride =1, padding =1):
        super(IntConv2d,self).__init__()
        self.weight = torch.randint(0,127,(out_channels, kernel_size, kernel_size, in_channels), dtype=torch.int8)
        self.stride = stride
        self.padding = padding

    def forward(self,x):
        # trans_weight = torch.flip(self.weight,[1,2]).transpose(0,3).contiguous()
        # trans_weight = self.weight.permute(0,2,3,1).contiguous()
        trans_weight = self.weight
        y = int8conv_cuda.int8_conv(x,trans_weight,self.stride, self.padding,1)
        y = (y > 127).int()*5
        y = y.type(torch.int8)
        return y
    
    def cuda(self):
        self.weight = self.weight.cuda()
    
# cudnn은 4의 배수만
input_channel= 4
conv = IntConv2d(input_channel,32,3,1,1)
print(conv.weight.shape)
x = torch.randint(0,127,(1,32,32,input_channel), dtype=torch.int8).cuda()


torch.Size([32, 3, 3, 4])


In [14]:
with torch.no_grad():
    conv.cuda()
    y = conv(x)
print(y.shape, y.device)

torch.Size([1, 32, 32, 32]) cuda:0


In [15]:
import numpy as np 
conv_data = conv.weight.detach().cpu().numpy()
x_data = x.detach().cpu().numpy()
y_data = y.detach().cpu().numpy()
print(f"x data - {x.shape} \n{x_data}\n")
print(f"conv data - {conv_data.shape}\n{conv_data}\n")
print(f"y data - {y.shape}\n{y}")

x data - torch.Size([1, 32, 32, 4]) 
[[[[ 57  64  79   8]
   [ 92 119  21  44]
   [ 19 113  61  73]
   ...
   [ 91  27  64 124]
   [ 26 109  37  97]
   [111   2  70  82]]

  [[ 20  16  61  39]
   [ 15 108 104  61]
   [ 18  77 116  52]
   ...
   [ 86  64  89 122]
   [ 73  78  15 106]
   [  8 113  50 103]]

  [[ 82  98  62  69]
   [  5   7  64 113]
   [ 34 112  71  16]
   ...
   [ 49  26  12  32]
   [106  30  68  62]
   [ 42 102   2  31]]

  ...

  [[121  84  90  54]
   [ 41  97  25  55]
   [120  96  76 103]
   ...
   [ 68  72  79 103]
   [ 46  25  62 104]
   [ 87  47  74 121]]

  [[ 17 104  51  65]
   [ 91 102   0 124]
   [ 98   5  37 101]
   ...
   [115  97  33 100]
   [ 85 124 113  44]
   [ 36 105 119  18]]

  [[ 61  39 110  52]
   [ 57  15  59  77]
   [119 115  96  97]
   ...
   [ 32 107  75  74]
   [ 12 111  62  61]
   [ 23 112  85  88]]]]

conv data - (32, 3, 3, 4)
[[[[ 45  18  29  96]
   [126  77  63  76]
   [103  31 107  76]]

  [[101  19  33  22]
   [ 86  53  64  81]
   [ 58  87

VGG 모델 테스트

In [16]:
import torch.nn as nn

class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 100, dropout: float = 0.5) -> None:
        super().__init__()
        self.features = features
        self.avgpool = IntPool(7,1,0,1)
        self.classifier = nn.Sequential(
            IntLinear(512,4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            IntLinear(4096,4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            IntLinear(4096,num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
    def cuda(self):
        for layer in model.modules():
            if 'Int' in str(type(layer)):
                layer.cuda()

def make_layers(cfg, batch_norm: bool = False) -> nn.Sequential:
    layers = []
    in_channels = 4
    for vs in cfg:
        for v in vs:
            v = int(v)
            conv2d = IntConv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU()]
            in_channels = v
        layers += [IntPool(kernel_size=2, stride=2)]
    return nn.Sequential(*layers)


cfgs = {
    "D": [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512],[512, 512, 512]],
    # "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
}

def int_vgg(cfg: str, **kwargs) -> VGG:
    model = VGG(make_layers(cfgs[cfg]), **kwargs)
    return model

In [17]:
modules = []
before_l = []
after_l = []
hooks = []

def hook_fn(module, input, output):
    modules.append(module)
    before_l.append(input[0])
    after_l.append(output)

def add_forward_hook(net, hooks):
    for name, layer in net._modules.items():
        if isinstance(layer, nn.Sequential) or isinstance(layer, torchvision.models.vgg.VGG):
            add_forward_hook(layer, hooks)
        else:
            hook = layer.register_forward_hook(hook_fn)
            hooks.append(hook)
            
    return hooks

def remove_forward_hook(hooks):
    for i in hooks:
        i.remove()
# out = model((torch.randn(1,3,32,32)))

In [18]:
# class test_module(Module):
#     def __init__(self,num_classes= 100):
#         super(test_module,self).__init__()
#         self.layers = nn.Sequential(
#             IntLinear(512,4096),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             IntLinear(4096,4096),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             IntLinear(4096,num_classes),
#         )
#     def forward(self,x):
#         for l in self.layers:
#             x = l(x)
#             print(x.shape, x.dtype)
#         return x
#     def cuda(self):
#         for layer in model.modules():
#             if 'Int' in str(type(layer)):
#                 layer.cuda()

# model = test_module()
# x = torch.randint(-127,127,(1,512)).cuda()
# model.eval()
# model.cuda()
# print(model.layers[0].weight)
# print(x.dtype)
# with torch.no_grad():
#     y = model(x)



In [19]:
model = int_vgg("D")
model.eval()
hooks = add_forward_hook(model, hooks)
# remove hook, hook works at once
remove_forward_hook(hooks)
model.cuda()
with torch.no_grad():
    x = torch.randint(-127,127,(1,224,224,4), dtype=torch.int8).cuda()
    y = model(x)
    print(len(hooks), len(modules), len(before_l), len(after_l))
    remove_forward_hook(hooks)
    hooks=[]
print(y.dtype, y.shape, y)
    

39 0 0 0
torch.int8 torch.Size([1, 100]) tensor([[5, 5, 0, 0, 0, 0, 5, 0, 0, 5, 5, 0, 5, 0, 0, 5, 5, 0, 5, 5, 5, 0, 0, 5,
         5, 0, 0, 5, 0, 0, 5, 0, 0, 5, 0, 0, 0, 5, 0, 0, 0, 0, 0, 5, 0, 5, 0, 5,
         5, 0, 5, 5, 5, 0, 0, 0, 0, 5, 0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 5, 5,
         0, 5, 5, 5, 0, 0, 0, 0, 5, 0, 5, 0, 0, 5, 5, 0, 0, 5, 5, 5, 0, 0, 0, 5,
         5, 0, 5, 5]], device='cuda:0', dtype=torch.int8)


In [20]:
i = torch.randint(-127, 127,(1,4,4,3), dtype=torch.int8).cuda()
lay = nn.Dropout(0.5)

lay.cuda()
with torch.no_grad():
    lay.eval()
    y = lay(i)
    k = nn.functional.relu(y)
print(y)
print(k)

tensor([[[[ -43, -119,   21],
          [  14,  110,  -41],
          [-107,   15, -115],
          [-111,   54,  -64]],

         [[  16,  -11,   95],
          [-108,  -73,  -38],
          [-105,  -15,  -51],
          [  98,  -34,   41]],

         [[ -75,  -17,  -16],
          [  21,   94,  -50],
          [  30, -124,    8],
          [  39,  -60, -115]],

         [[-104,  120, -116],
          [-115,   66,  110],
          [ -40,  -91,  -93],
          [   6,  -79,  105]]]], device='cuda:0', dtype=torch.int8)
tensor([[[[  0,   0,  21],
          [ 14, 110,   0],
          [  0,  15,   0],
          [  0,  54,   0]],

         [[ 16,   0,  95],
          [  0,   0,   0],
          [  0,   0,   0],
          [ 98,   0,  41]],

         [[  0,   0,   0],
          [ 21,  94,   0],
          [ 30,   0,   8],
          [ 39,   0,   0]],

         [[  0, 120,   0],
          [  0,  66, 110],
          [  0,   0,   0],
          [  6,   0, 105]]]], device='cuda:0', dtype=torch.int8)


In [21]:
from models import vgg

model = vgg.int_vgg16("D")
x = torch.randint(-128,127,(1,224,224,4),dtype=torch.int8).cuda()
model.eval()
model.cuda()
with torch.no_grad():
    y = model(x)
    print(y.shape)
print(y)

torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.int32
torch.Size([1, 100])
tensor([[ 1215963,   237793,   591355,   694649,    75676,  -523889,  -540199,
          -558748,  -531662,  -290858,     8800,  -858874,   491848,  -475596,
          -555899,   114556,   -20991,  -138834,  -372881,  -115454,   988935,
          -477697,   120380,  -369514,   193641,   421465,  -416288,   295891,
           852482,   380552, -1040677,   155962,  -770499,  -325930,  -592913,
           717186,    16002,  -448430,  -358064, -1259084,  -212092,  1522873,
         -1213738,  -419031,  -658384,  -139878,   201685,  -340574,   -28981,
           987099,  -424457, -1613387,   674524,   895133,  -344503,     -621,
           459730,  -401868,  -466956,  -328991,  -412663,   433640,  -881373,
           231411,  -890765,   -26702,   379205,   567023,  1050320, -1022450,
           680504,  -623356,   39

In [22]:
x = torch.ones((3,3), dtype=torch.int64)*(2**18+1)
print(x)
x = x + 2**15
print(x)
x = torch.clamp(x, min=0, max=2**16-1)
print(x)
x = torch.sqrt(x)
print(x)
x = x-128
print(x)
x = x.type(torch.int8)
print(x)

tensor([[262145, 262145, 262145],
        [262145, 262145, 262145],
        [262145, 262145, 262145]])
tensor([[294913, 294913, 294913],
        [294913, 294913, 294913],
        [294913, 294913, 294913]])
tensor([[65535, 65535, 65535],
        [65535, 65535, 65535],
        [65535, 65535, 65535]])
tensor([[255.9980, 255.9980, 255.9980],
        [255.9980, 255.9980, 255.9980],
        [255.9980, 255.9980, 255.9980]])
tensor([[127.9980, 127.9980, 127.9980],
        [127.9980, 127.9980, 127.9980],
        [127.9980, 127.9980, 127.9980]])
tensor([[127, 127, 127],
        [127, 127, 127],
        [127, 127, 127]], dtype=torch.int8)


In [23]:
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]]).reshape(1,3,3,1)
print(a.shape, a,sep='\n')
a = a.repeat(1,1,1,4)
print(a.shape, a,sep='\n')

torch.Size([1, 3, 3, 1])
tensor([[[[1],
          [2],
          [3]],

         [[4],
          [5],
          [6]],

         [[7],
          [8],
          [9]]]])
torch.Size([1, 3, 3, 4])
tensor([[[[1, 1, 1, 1],
          [2, 2, 2, 2],
          [3, 3, 3, 3]],

         [[4, 4, 4, 4],
          [5, 5, 5, 5],
          [6, 6, 6, 6]],

         [[7, 7, 7, 7],
          [8, 8, 8, 8],
          [9, 9, 9, 9]]]])


In [24]:
class FloatConv2d(Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride =1, padding =1):
        super(FloatConv2d,self).__init__()
        self.weight = torch.ones((out_channels, kernel_size, kernel_size, in_channels),dtype=torch.float32)
        self.stride = stride
        self.padding = padding

    def forward(self,x):
        # trans_weight = torch.flip(self.weight,[1,2]).transpose(0,3).contiguous()
        # trans_weight = self.weight.permute(0,2,3,1).contiguous()
        trans_weight = self.weight
        y = int8conv_cuda.float_conv(x,trans_weight,self.stride, self.padding,1)
        return y
    
    def cuda(self):
        self.weight = self.weight.cuda()


input_channel= 4
conv = FloatConv2d(input_channel,1,3,1,1)
conv.cuda()
print(f"conv: {conv.weight.shape}\n{conv.weight}")
x = torch.tensor([[1,2,3],[4,5,6],[7,8,9]],dtype=torch.float32).reshape(1,3,3,1).repeat(1,1,1,input_channel).cuda()
print(f"x : {x.shape} \n{x}")
with torch.no_grad():
    y= conv(x)
    print(f"y : {y.shape}\n{y}")

conv: torch.Size([1, 3, 3, 4])
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]], device='cuda:0')
x : torch.Size([1, 3, 3, 4]) 
tensor([[[[1., 1., 1., 1.],
          [2., 2., 2., 2.],
          [3., 3., 3., 3.]],

         [[4., 4., 4., 4.],
          [5., 5., 5., 5.],
          [6., 6., 6., 6.]],

         [[7., 7., 7., 7.],
          [8., 8., 8., 8.],
          [9., 9., 9., 9.]]]], device='cuda:0')
y : torch.Size([1, 3, 3, 1])
tensor([[[[ 48.],
          [ 84.],
          [ 64.]],

         [[108.],
          [180.],
          [132.]],

         [[ 96.],
          [156.],
          [112.]]]], device='cuda:0')
