In [1]:
!pip install ninja

Collecting ninja
  Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (180 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/180.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.13.0


In [2]:
import os

# 1. Nsight Systems (nsys) 설치 여부 확인 및 경로 설정
if os.path.exists('/usr/local/cuda/bin/nsys'):
    print("Found nsys in /usr/local/cuda/bin. Adding to PATH...")
    os.environ['PATH'] += ':/usr/local/cuda/bin'
else:
    print("Installing Nsight Systems...")
    # 최신 버전 설치 (apt repository 업데이트)
    !apt-get update -y
    !apt-get install -y nsight-systems-2023.3.3
    # 설치 후 경로 추가 (보통 /usr/local/bin에 생기지만 혹시 모르니)
    os.environ['PATH'] += ':/usr/local/cuda/bin'

# 2. Nsight Compute (ncu) 설치 여부 확인
if os.path.exists('/usr/local/cuda/bin/ncu'):
    print("Found ncu in /usr/local/cuda/bin.")
else:
    print("Installing Nsight Compute...")
    !apt-get install -y nsight-compute-2023.3.1
    os.environ['PATH'] += ':/usr/local/cuda/bin'

print("\n=== Installation Check ===")
!nsys --version
!ncu --version

Installing Nsight Systems...
Hit:1 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:4 https://cli.github.com/packages stable InRelease [3,917 B]
Get:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:6 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:9 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ Packages [85.0 kB]
Get:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [2,361 kB]
Get:11 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease [18.1 kB]
Get:12 https://cli.github.com/packages stable/main amd64 Packages [355 B]
Hit:13 https://ppa.launchp

In [3]:
%%writefile profile_run.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
import torch.cuda.nvtx as nvtx
import time

# ---------------------------------------------------------
# 1. CUDA kernel source (compiled below with load_inline)
# ---------------------------------------------------------
cuda_source = """
#include <cuda_runtime.h>
#include <torch/extension.h>

// sgemm2D: implicit-GEMM convolution with packed 4-bit weights.
//
// For every output row m and output column n it computes
//   C[m][n] = (sum_k A[m][k] * W[k][n] - z[n] * sum_k A[m][k]) * s[n]
// where A is the im2col view of Input, gathered on the fly (never materialized),
// W are the raw unsigned 4-bit codes (0..15), s = scale and z = zero_point.
// This equals sum_k A[m][k] * (W[k][n] - z[n]) * s[n]. The zero point and
// scale are therefore applied once per output in the epilogue instead of on
// every multiply.
//
// GEMM sizes, as computed below:
//   M = batch * Out_H * Out_W   (row m -> (batch_idx, oh, ow))
//   K = In_C * K_H * K_W        (k -> (ic, kh, kw), kw varying fastest)
//   N = gridDim.x * BN          (the grid's x size defines N)
//
// Layouts implied by the indexing:
//   Input : contiguous NCHW float (batch, In_C, In_H, In_W).
//   B     : row-major int32 with N / 8 ints per row. Packed int j holds
//           columns 8*j .. 8*j + 7, column 8*j + s in bits [4*s, 4*s + 3].
//           Rows at or beyond K are never read. Rows are fetched with 16-byte
//           int4 loads; this presumably requires B, N / 8 and BN / 8 to keep
//           those loads 16-byte aligned (the host checks N % 32 == 0).
//   C     : row-major float M x N; float4 stores are used when all 4 columns fit.
//   scale, zero_point : one float per output column n < N.
//
// Launch: blockIdx.y picks a BM-row tile and blockIdx.x a BN-column tile. Each
// thread computes TM rows x TN_LOG columns, so blockDim.x is assumed to be
// (BM / TM) * (BN / TN_LOG), which is what the host wrapper launches.
template<int BM, int BN, int BK, int TM, int TN_PHY>
__global__ void sgemm2D(float* Input, int* B, float* C, float* scale, float* zero_point,
    int batch,
    int In_H, int In_W, int In_C,
    int Out_H, int Out_W,
    int K_H, int K_W,
    int Pad_H, int Pad_W,
    int Stride_H, int Stride_W,
    int Dilation_H, int Dilation_W)
{
    // Block (cRow, cCol) owns the BM x BN output tile starting at (cRow * BM, cCol * BN).
    int cRow = blockIdx.y;
    int cCol = blockIdx.x;

    int M = batch * Out_H * Out_W;
    int N = gridDim.x * BN;
    int K = In_C * K_H * K_W;

    // Each thread reads TN_PHY packed ints = TN_LOG logical output columns.
    constexpr int TN_LOG = TN_PHY * 8;

    int threadRow = threadIdx.x / (BN / TN_LOG);
    int threadCol = threadIdx.x % (BN / TN_LOG);

    // As is stored k-major (As[k * BM + m]); Bs holds BK rows of BN / 8 packed ints.
    __shared__ float As[BM * BK];
    __shared__ int Bs[BK * (BN / 8)];

    // Point B at this block's first packed column and C at this block's output tile.
    B += cCol * (BN / 8);
    C += cRow * BM * N + cCol * BN;

    float threadResults[TM * TN_LOG] = {0.0};
    float regM[TM] = {0.0};
    int regN[TN_PHY] = {0};

    // Optimization: running sum of this thread's A values over all of K.
    // The epilogue uses it to apply the zero-point correction once.
    float sumA[TM] = {0.0f};

    // (The former my_scales / my_zeros loads were removed entirely; scale and
    //  zero_point are now read only in the epilogue.)

    for (int bk = 0; bk < K; bk += BK) {

        // Gather the BM x BK tile of the implicit im2col matrix A into As.
        // Out-of-range rows/columns and padding taps are stored as 0, so they
        // add nothing to the products or to sumA.
        for (int idx = threadIdx.x; idx < BK * BM; idx += blockDim.x) {
            int r = idx % BM;
            int c = idx / BM;

            int globalM = cRow * BM + r;
            int curr_k = bk + c;

            float val = 0.0f;
            if (curr_k < K && globalM < M) {
                // Output row m -> (batch_idx, oh, ow).
                int area_out = Out_H * Out_W;
                int batch_idx = globalM / area_out;
                int pixel_rem = globalM % area_out;
                int oh = pixel_rem / Out_W;
                int ow = pixel_rem % Out_W;

                // Reduction index k -> (ic, kh, kw), kw varying fastest.
                int kc = curr_k % K_W;
                int k_rem = curr_k / K_W;
                int kh = k_rem % K_H;
                int ic = k_rem / K_H;

                // Input coordinate sampled by this tap; it may land in the padding.
                int ih = oh * Stride_H - Pad_H + kh * Dilation_H;
                int iw = ow * Stride_W - Pad_W + kc * Dilation_W;

                if (ih >= 0 && ih < In_H && iw >= 0 && iw < In_W) {
                    // 64-bit offset into the NCHW input.
                    long long input_idx =
                        (long long)batch_idx * (In_C * In_H * In_W) +
                        (long long)ic * (In_H * In_W) +
                        (long long)ih * In_W + iw;
                    val = Input[input_idx];
                }
            }
            As[c * BM + r] = val;
        }

        // Copy BK rows of packed weights into Bs, one int4 (4 packed ints) per step.
        // Rows at or beyond K are zero-filled instead of being read.
        int num_int4 = (BK * BN / 8) / 4;
        for (int idx = threadIdx.x; idx < num_int4; idx += blockDim.x) {
            int r = idx / ((BN / 8) / 4);
            int c_int4 = idx % ((BN / 8) / 4);
            int c = c_int4 * 4;

            if (bk + r < K) {
                reinterpret_cast<int4*>(&Bs[r * (BN / 8) + c])[0] =
                    reinterpret_cast<int4*>(&B[r * (N / 8) + c])[0];
            } else {
                Bs[r * (BN / 8) + c + 0] = 0;
                Bs[r * (BN / 8) + c + 1] = 0;
                Bs[r * (BN / 8) + c + 2] = 0;
                Bs[r * (BN / 8) + c + 3] = 0;
            }
        }
        __syncthreads();

        // Advance B by BK rows for the next K tile.
        B += BK * (N / 8);

        for (int dot = 0; dot < BK; ++dot) {
            for (int i = 0; i < TM; ++i){
                float a_val = As[dot * BM + threadRow * TM + i];
                regM[i] = a_val;
                sumA[i] += a_val; // Key optimization: only accumulate the raw A value here.
            }
            for (int i = 0; i < TN_PHY; ++i) {
                regN[i] = Bs[dot * (BN / 8) + threadCol * TN_PHY + i];
            }

            for (int i = 0; i < TN_PHY; ++i) {
                int packed_val = regN[i];

                // Unpack 8 unsigned 4-bit codes; nibble subN is logical column i * 8 + subN.
                for (int subN = 0; subN < 8; ++subN) {
                    int int4_val = (packed_val >> (subN * 4)) & 0xF;
                    float w_val = float(int4_val); // Less work: the 4-bit code is converted straight to float.

                    for (int m = 0; m < TM; ++m) {
                        int resNidx = i * 8 + subN;
                        // Less work: no per-multiply zero-point subtraction or scale multiplication.
                        threadResults[m * TN_LOG + resNidx] += regM[m] * w_val;
                    }
                }
            }
        }
        // All threads must finish reading As/Bs before the next tile overwrites them.
        __syncthreads();
    }

    // Epilogue: apply scale / zero point once per output, using the distributive law.
    for (int resIdxN = 0; resIdxN < TN_LOG; ++resIdxN) {
        int globalN = cCol * BN + threadCol * TN_LOG + resIdxN;
        float s = 1.0f;
        float z = 0.0f;

        // scale / zero_point are read from global memory only here, once per column.
        if (globalN < N) {
            s = scale[globalN];
            z = zero_point[globalN];
        }

        for (int resIdxM = 0; resIdxM < TM; ++resIdxM) {
            // C = (A*W - sum(A)*Z) * S
            threadResults[resIdxM * TN_LOG + resIdxN] =
                (threadResults[resIdxM * TN_LOG + resIdxN] - sumA[resIdxM] * z) * s;
        }
    }

    // Write the TM x TN_LOG results. C was already offset to this block's tile,
    // so the indices below are tile-relative. Use a float4 store when all 4
    // columns are in range, otherwise store element by element.
    for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
        for (uint resIdxN = 0; resIdxN < TN_LOG; resIdxN += 4) {
            int globalRowC = cRow * BM + threadRow * TM + resIdxM;
            int globalColC = cCol * BN + threadCol * TN_LOG + resIdxN;

            if (globalRowC < M) {
                float4 tmp;
                tmp.x = threadResults[resIdxM * TN_LOG + resIdxN];
                tmp.y = threadResults[resIdxM * TN_LOG + resIdxN + 1];
                tmp.z = threadResults[resIdxM * TN_LOG + resIdxN + 2];
                tmp.w = threadResults[resIdxM * TN_LOG + resIdxN + 3];

                if (globalColC + 4 <= N) {
                    reinterpret_cast<float4 *>(&C[(threadRow * TM + resIdxM) * N + threadCol * TN_LOG + resIdxN])[0] = tmp;
                }
                else {
                    float* ptr = &C[(threadRow * TM + resIdxM) * N + threadCol * TN_LOG + resIdxN];
                    if (globalColC + 0 < N) ptr[0] = tmp.x;
                    if (globalColC + 1 < N) ptr[1] = tmp.y;
                    if (globalColC + 2 < N) ptr[2] = tmp.z;
                    if (globalColC + 3 < N) ptr[3] = tmp.w;
                }
            }
        }
    }
}

// Host launcher for sgemm2D (exposed to Python through load_inline).
//
// Arguments as the kernel consumes them:
//   Input      : float32 CUDA tensor, contiguous NCHW (batch, In_C, In_H, In_W).
//   B_packed   : int32 CUDA tensor of shape (>= K, N / 8), 8 unsigned 4-bit
//                weights per int, with K = In_C * K_H * K_W.
//   C          : float32 CUDA output matrix of shape (>= M, N), with
//                M = batch * Out_H * Out_W. N is taken from C.size(1).
//   scale, zero_point : float32 CUDA tensors with at least N elements.
//   remaining ints : convolution geometry (sizes, padding, stride, dilation).
// Invalid arguments and launch failures raise via TORCH_CHECK. The launch itself
// is asynchronous, so errors during kernel execution surface at a later sync.
void sgemm_int4_cuda(
    torch::Tensor Input,
    torch::Tensor B_packed,
    torch::Tensor C,
    torch::Tensor scale,
    torch::Tensor zero_point,
    int batch, int In_H, int In_W, int In_C,
    int Out_H, int Out_W,
    int K_H, int K_W,
    int Pad_H, int Pad_W,
    int Stride_H, int Stride_W,
    int Dilation_H, int Dilation_W
) {
    // Tile configuration; these are the template arguments of the launch below.
    const int BM = 64;
    const int BN = 64;
    const int BK = 8;
    const int TM = 4;
    const int TN_PHY = 1;

    TORCH_CHECK(C.dim() == 2, "C must be a 2-D (M, N) matrix");
    TORCH_CHECK(B_packed.dim() == 2, "B_packed must be a 2-D (K, N / 8) matrix");

    int M = batch * Out_H * Out_W;
    int K = In_C * K_H * K_W;
    int N = C.size(1);

    // The kernel derives N as gridDim.x * BN, so N must tile exactly by BN.
    TORCH_CHECK(N % BN == 0, "N must be a multiple of ", BN);
    // B rows are fetched as 16-byte int4 vectors (4 packed ints = 32 columns).
    TORCH_CHECK(N % 32 == 0, "N must be a multiple of 32 for int4 loading");
    // No constraint on K: A is gathered element by element and B rows at or
    // beyond K are zero-filled, so e.g. K = 3 * 3 * 3 = 27 is valid.

    TORCH_CHECK(Input.is_cuda() && B_packed.is_cuda() && C.is_cuda() &&
                    scale.is_cuda() && zero_point.is_cuda(),
                "all tensors must be CUDA tensors");

    TORCH_CHECK(Input.is_contiguous(), "Input must be contiguous");
    TORCH_CHECK(B_packed.is_contiguous(), "B_packed must be contiguous");
    TORCH_CHECK(C.is_contiguous(), "C must be contiguous");
    TORCH_CHECK(scale.is_contiguous(), "scale must be contiguous");
    TORCH_CHECK(zero_point.is_contiguous(), "zero_point must be contiguous");

    // The kernel trusts these sizes for its raw-pointer indexing; reject
    // undersized buffers instead of reading or writing out of bounds.
    TORCH_CHECK(Input.numel() >= (int64_t)batch * In_C * In_H * In_W,
                "Input has fewer than batch * In_C * In_H * In_W elements");
    TORCH_CHECK(C.size(0) >= M, "C must have at least batch * Out_H * Out_W rows");
    TORCH_CHECK(B_packed.size(1) * 8 == N, "B_packed must have N / 8 columns");
    TORCH_CHECK(B_packed.size(0) >= K, "B_packed must have at least In_C * K_H * K_W rows");
    TORCH_CHECK(scale.numel() >= N && zero_point.numel() >= N,
                "scale and zero_point must have at least N elements");

    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    if (M <= 0 || N == 0) {
        return;
    }

    int num_threads = (BM * BN) / (TM * (TN_PHY * 8));
    dim3 blockDim(num_threads);
    dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);

    sgemm2D<BM, BN, BK, TM, TN_PHY><<<gridDim, blockDim>>>(
        Input.data_ptr<float>(),
        B_packed.data_ptr<int>(),
        C.data_ptr<float>(),
        scale.data_ptr<float>(),
        zero_point.data_ptr<float>(),
        batch, In_H, In_W, In_C,
        Out_H, Out_W,
        K_H, K_W,
        Pad_H, Pad_W,
        Stride_H, Stride_W,
        Dilation_H, Dilation_W
    );

    // Report launch errors here rather than at some unrelated later CUDA call.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "sgemm2D launch failed: ", cudaGetErrorString(err));
}
"""

# C++ declaration of the host launcher defined in cuda_source. load_inline builds
# Python bindings for the names listed in `functions=` from these declarations,
# so this signature must match the definition exactly.
cpp_source = """
void sgemm_int4_cuda(
    torch::Tensor Input,
    torch::Tensor B_packed,
    torch::Tensor C,
    torch::Tensor scale,
    torch::Tensor zero_point,
    int batch, int In_H, int In_W, int In_C,
    int Out_H, int Out_W,
    int K_H, int K_W,
    int Pad_H, int Pad_W,
    int Stride_H, int Stride_W,
    int Dilation_H, int Dilation_W
);
"""

# ---------------------------------------------------------
# 2. Compile (JIT)
# ---------------------------------------------------------
# JIT-compile the extension and bind sgemm_int4_cuda as a Python function.
# nvcc flags: -O3 for optimization; -lineinfo keeps source-line information so
# profilers such as Nsight Compute can map metrics back to the CUDA source.
sgemm_module = load_inline(
    name="sgemm_int4_v1",
    functions=['sgemm_int4_cuda'],
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    with_cuda=True,
    extra_cuda_cflags=["-O3", "-lineinfo"],
    verbose=False,
)

# ---------------------------------------------------------
# 3. Quantized Layer Definition
# ---------------------------------------------------------
class QuantizedConv2d(nn.Module):
    """Conv2d-like layer that runs the INT4 implicit-GEMM CUDA kernel.

    Buffers are created with placeholder values here. They are presumably filled
    with real quantized weights elsewhere (TODO confirm):
        w_packed   : int32, (K + pad_k, (N + pad_n) // 8); 8 packed 4-bit weights per int.
        scale      : float32, (N + pad_n,), per-output-channel scale.
        zero_point : float32, (N + pad_n,), per-output-channel zero point.
        bias       : float32, (N + pad_n,), or None when ``bias=False``.
    Here K = in_channels * kernel_size**2 and N = out_channels.

    Args:
        in_channels, out_channels: convolution channel counts.
        kernel_size, stride, padding, dilation: ints, applied to both height and width.
        bias: whether to allocate the bias buffer that ``forward`` adds to the output.
        BM, BK, BN: tile sizes used for buffer padding (K is padded to a multiple
            of BK, N to a multiple of BN). BM is stored but not otherwise used here.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True,
                 BM=128, BK=8, BN=128):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.BM = BM
        self.BK = BK
        self.BN = BN

        # GEMM view: K = reduction length (im2col columns), N = output channels.
        self.K = self.in_channels * self.kernel_size * self.kernel_size
        self.N = self.out_channels

        # Padding that rounds N up to a multiple of BN and K up to a multiple of BK.
        self.pad_n = (BN - self.N % BN) % BN
        self.pad_k = (BK - self.K % BK) % BK

        self.register_buffer('w_packed', torch.zeros(self.K + self.pad_k, (self.N + self.pad_n) // 8, dtype=torch.int32))
        self.register_buffer('scale', torch.ones(self.N + self.pad_n, dtype=torch.float32))
        self.register_buffer('zero_point', torch.zeros(self.N + self.pad_n, dtype=torch.float32))

        if bias:
            self.register_buffer('bias', torch.zeros(self.N + self.pad_n, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Run the quantized convolution on ``x``.

        Args:
            x: tensor of shape (batch, in_channels, H, W). The extension reads it
                as float32 NCHW on the same device as the buffers.

        Returns:
            The GEMM-layout output of shape (batch * h_out * w_out, N_padded).
            Row m is output pixel m (batch-major, then h, then w) and column n is
            output channel n; columns >= out_channels are padding. The result is
            NOT reshaped to NCHW.

        Raises:
            ValueError: if x's channel count differs from ``in_channels``.
        """
        # 1. Input dimensions.
        batch_size, in_channels, h_in, w_in = x.shape
        if in_channels != self.in_channels:
            # In_C is passed straight to the kernel, which uses it to index both x and
            # w_packed; a mismatch would silently read out of bounds.
            raise ValueError(
                f"QuantizedConv2d expected {self.in_channels} input channels, got {in_channels}"
            )

        # 2. Output spatial size (same formula as torch.nn.Conv2d).
        h_out = (h_in + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1) // self.stride + 1
        w_out = (w_in + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1) // self.stride + 1

        # GEMM M dimension: one row per output pixel across the whole batch.
        M = batch_size * h_out * w_out
        N_padded = self.w_packed.shape[1] * 8

        # Message means "N-axis padding required".
        assert N_padded % self.BN == 0, "N축 패딩 필요"

        # 3. Output matrix (M x N_padded); the kernel writes every element.
        C_padded = torch.empty((M, N_padded), dtype=torch.float32, device=x.device)

        # The kernel indexes x as contiguous NCHW.
        x = x.contiguous()

        # 4. Implicit-GEMM kernel: reads x directly, with no F.unfold / explicit im2col.
        sgemm_module.sgemm_int4_cuda(
            x,               # Input (Batch, In_C, In_H, In_W)
            self.w_packed,   # B (Packed Weights)
            C_padded,        # C (Output Matrix)
            self.scale,
            self.zero_point,
            batch_size,
            h_in, w_in, in_channels,
            h_out, w_out,
            self.kernel_size, self.kernel_size,
            self.padding, self.padding,
            self.stride, self.stride,
            self.dilation, self.dilation
        )

        # The kernel applies only scale/zero point, so add the bias here
        # (broadcast over every output row).
        if self.bias is not None:
            C_padded.add_(self.bias)

        return C_padded

# ---------------------------------------------------------
# 4. Main Execution Block for Profiling
# ---------------------------------------------------------
def main(num_warmup=5, num_iters=10):
    """Run one QuantizedConv2d layer under NVTX ranges so nsys/ncu can profile it.

    Args:
        num_warmup: un-annotated iterations run first, followed by a synchronize.
        num_iters: iterations recorded inside the "My_Model_Inference" NVTX range,
            each in its own "Iter_<i>" sub-range.
    """
    device = torch.device('cuda')

    # A single heavy layer is enough for kernel profiling; a full model could be used instead.
    model = QuantizedConv2d(64, 128, kernel_size=3, stride=1, padding=1).to(device)
    # Allocate directly on the GPU instead of creating on the host and copying.
    dummy_input = torch.randn(32, 64, 32, 32, device=device)

    print("Warm-up...")
    for _ in range(num_warmup):
        model(dummy_input)
    torch.cuda.synchronize()

    print("Profiling Start...")

    # nvtx.range is a context manager, so every push is matched by a pop even if
    # an iteration raises.
    with nvtx.range("My_Model_Inference"):
        for i in range(num_iters):
            with nvtx.range(f"Iter_{i}"):
                model(dummy_input)
                # Sync inside the range so Iter_<i> spans the kernel's GPU time
                # (drop this for deployment).
                torch.cuda.synchronize()

    print("Done.")


if __name__ == "__main__":
    main()

Writing profile_run.py


In [4]:
# --trace=cuda,nvtx,osrt  <-- 여기서 osrt 제거
!nsys profile \
  --trace=cuda,nvtx \
  --output=nsys_result_32_im2col_fuse \
  --force-overwrite=true \
  --stats=true \
  python profile_run.py

Warm-up...
Profiling Start...
Done.
Generating '/tmp/nsys-report-a59d.qdstrm'
[3/7] Executing 'nvtx_sum' stats report

 Time (%)  Total Time (ns)  Instances    Avg (ns)      Med (ns)     Min (ns)    Max (ns)   StdDev (ns)   Style         Range       
 --------  ---------------  ---------  ------------  ------------  ----------  ----------  -----------  -------  ------------------
     50.0       51,499,099          1  51,499,099.0  51,499,099.0  51,499,099  51,499,099          0.0  PushPop  My_Model_Inference
      5.1        5,204,104          1   5,204,104.0   5,204,104.0   5,204,104   5,204,104          0.0  PushPop  Iter_0            
      5.0        5,159,819          1   5,159,819.0   5,159,819.0   5,159,819   5,159,819          0.0  PushPop  Iter_1            
      5.0        5,146,267          1   5,146,267.0   5,146,267.0   5,146,267   5,146,267          0.0  PushPop  Iter_5            
      5.0        5,142,368          1   5,142,368.0   5,142,368.0   5,142,368   5,142,368

In [5]:
!./sgemm_run

/bin/bash: line 1: ./sgemm_run: No such file or directory


In [6]:
# Colab 셀에서 실행
!ncu \
  --set full \
  --kernel-name regex:sgemm2D \
  --launch-count 1 \
  -o ncu_result_32_im2col_fuse \
  -f \
  python profile_run.py

==PROF== Connected to process 2318 (/usr/bin/python3.12)
Warm-up...
Profiling Start...
Done.
==PROF== Disconnected from process 2318
Available Kernels:
1. sgemm2D
