import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging
import torchvision

logging.set_reportable_log_level(logging.Level.Graph)
torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 1)


class Model(torch.nn.Module):
    """Linear layer whose output is scattered into a zero tensor by index.

    The forward pass writes the linear output into a ``(1, 1, 1)`` tensor
    through an indexed assignment driven by ``torch.tril_indices``.
    """

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        """Return a ``(1, 1, 1)`` tensor holding ``self.lin(a)``.

        Args:
            a: Input tensor passed to the ``nn.Linear(1, 1)`` layer.

        Returns:
            A zero-initialised ``(1, 1, 1)`` tensor, on ``a``'s device and with
            the linear output's dtype, whose lower-triangular entries have been
            overwritten with the linear output.
        """
        out = self.lin(a)
        tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
        # For a 1x1 matrix the lower triangle is the single element (0, 0).
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out
        return tril


class Deformable_Convolution(torch.nn.Module):
    """Deformable convolution whose offsets come from a regular convolution.

    A plain ``Conv2d`` predicts the sampling offsets from the input, and a
    ``torchvision.ops.DeformConv2d`` consumes the input together with those
    offsets.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        groups=1,
        offset_groups=1,
    ):
        """Build the offset predictor and the deformable convolution.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by the deformable conv.
            kernel_size: Side length of the (square) kernel of both layers.
            stride: Stride shared by both layers.
            dilation: Dilation shared by both layers; also used as padding.
            groups: ``groups`` argument of the deformable convolution.
            offset_groups: Multiplier on the number of offset channels.
        """
        super().__init__()
        # Two offset values per kernel tap.
        offset_channels = 2 * kernel_size * kernel_size
        # Both layers share kernel size, stride, padding and dilation, so the
        # offset map has the same spatial size as the deformable conv output.
        self.conv2d_offset = torch.nn.Conv2d(
            in_channels,
            offset_channels * offset_groups,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation,
            dilation=dilation,
        )
        self.conv2d = torchvision.ops.DeformConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            bias=False,
        )

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution."""
        offset = self.conv2d_offset(x)
        return self.conv2d(x, offset)


def _compile_trt(module, shape):
    """Compile ``module`` with Torch-TensorRT for one static FP32 input shape."""
    return torchtrt.compile(
        module,
        inputs=[
            torchtrt.Input(shape=shape),
        ],
        enabled_precisions={torch.float},
        truncate_long_and_double=True,
    )


if __name__ == "__main__":
    # --- Case 1: indexed assignment model -------------------------------
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
    model = Model().eval().to(DEVICE)

    with torch.inference_mode():
        out = model(tensor)
    print(f"Model: {out}")

    model_trt = _compile_trt(model, SHAPE)

    with torch.inference_mode():
        # Run the compiled module; calling the eager model here would make
        # the comparison below trivially true.
        out_trt = model_trt(tensor)
    print(f"Model TRT: {out_trt}")

    assert torch.max(torch.abs(out - out_trt)) < 1e-6

    # --- Case 2: deformable convolution ---------------------------------
    SHAPE2 = (1, 3, 4, 4)
    tensor2 = torch.randn(SHAPE2, dtype=torch.float32, device=DEVICE)
    model2 = Deformable_Convolution(SHAPE2[1], SHAPE2[1]).eval().to(DEVICE)

    with torch.inference_mode():
        out2 = model2(tensor2)
    print(f"Model2: {out2}")

    # Keep the second compiled module under its own name so the first one is
    # not overwritten.
    model_trt2 = _compile_trt(model2, SHAPE2)

    with torch.inference_mode():
        # Feed the 4-D tensor this model was built for, through the compiled
        # module.
        out_trt2 = model_trt2(tensor2)
    print(f"Model TRT2: {out_trt2}")

    assert torch.max(torch.abs(out2 - out_trt2)) < 1e-6