<a href="https://colab.research.google.com/github/ved1beta/Triton/blob/main/explain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# First, create a new notebook in Google Colab and paste this code into different cells

# Cell 1 - Install dependencies
%%capture
!pip install triton

# Cell 2 - Import libraries and check GPU
import torch
import triton

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")
else:
    # get_device_name(0) raises without a GPU, and the Triton kernel below
    # needs one too, so tell the user how to fix the runtime instead.
    print("No GPU detected - switch to a GPU runtime "
          "(Runtime > Change runtime type) before running the Triton cells.")

# Cell 3 - Neural Network Implementation
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

@triton.jit
def triton_nn_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    batch_size,
    in_features,
    out_features,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute one element of ``output = input @ weight.T`` (no bias).

    Intended to be launched on a 1-D grid of ``batch_size * out_features``
    programs. Program ``pid`` writes
    ``output[pid // out_features, pid % out_features]``, the dot product of one
    input row with one weight row. It walks the ``in_features`` dimension in
    chunks of ``BLOCK_SIZE``.

    Pointer layout (as indexed below): ``input`` is row-major
    ``(batch_size, in_features)``, ``weight`` is row-major
    ``(out_features, in_features)`` and ``output`` is row-major
    ``(batch_size, out_features)``. All three must be contiguous, and input
    and weight must share the same ``in_features`` row stride.

    ``batch_size`` is not read by the kernel; it stays in the signature so
    existing launch sites keep working.

    NOTE(review): one program per output element re-loads the whole input row
    ``out_features`` times; a tiled matmul kernel would be far faster.
    """
    pid = tl.program_id(axis=0)

    # Map the flat program id to (input/output row, weight row).
    batch_idx = pid // out_features
    out_idx = pid % out_features

    offsets = tl.arange(0, BLOCK_SIZE)
    input_row_ptr = input_ptr + batch_idx * in_features
    weight_row_ptr = weight_ptr + out_idx * in_features

    # Accumulate in fp32 regardless of the element type so half-precision
    # inputs do not lose precision over long reductions (no-op for fp32).
    acc = 0.0
    for block_start in range(0, in_features, BLOCK_SIZE):
        cols = block_start + offsets
        mask = cols < in_features
        # Out-of-range lanes load 0.0 and therefore add nothing to the sum,
        # so no extra multiplication by the mask is needed.
        x = tl.load(input_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        acc += tl.sum(x * w, axis=0)

    tl.store(output_ptr + batch_idx * out_features + out_idx, acc)

class _TritonLinearFunction(torch.autograd.Function):
    """Autograd wrapper computing ``x @ weight.T`` with ``triton_nn_kernel``.

    A bare Triton launch writes into a freshly allocated tensor that autograd
    knows nothing about. Without this wrapper, no gradient would ever reach
    ``weight`` (or ``x``). The forward pass runs the Triton kernel on CUDA
    tensors; the backward pass uses ordinary PyTorch matmuls.
    """

    @staticmethod
    def forward(ctx, x, weight, block_size):
        # x: (batch, in_features); weight: (out_features, in_features).
        batch_size, in_features = x.shape
        out_features, weight_in = weight.shape
        if in_features != weight_in:
            # The kernel uses one row stride for both operands; a mismatch
            # would silently read the wrong (or out-of-bounds) elements.
            raise ValueError(
                f"input has {in_features} features but weight expects {weight_in}"
            )

        if x.is_cuda:
            output = torch.empty((batch_size, out_features), device=x.device)
            grid = (batch_size * out_features,)
            triton_nn_kernel[grid](
                x.contiguous(),
                weight.contiguous(),
                output,
                batch_size,
                in_features,
                out_features,
                block_size,
            )
        else:
            # Triton kernels need a GPU; fall back to PyTorch so the module
            # still runs on CPU (e.g. in unit tests).
            output = x @ weight.t()

        ctx.save_for_backward(x, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_output @ weight        # (batch, in_features)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ x    # (out_features, in_features)
        # block_size is a plain int and receives no gradient.
        return grad_x, grad_weight, None


class TritonNeuralNetwork(torch.nn.Module):
    """Two-layer MLP whose first linear layer is computed by a Triton kernel.

    ``layer1`` holds the first layer's parameters. Its matmul runs through
    ``triton_nn_kernel`` (wrapped so gradients flow), and its bias is added
    afterwards in PyTorch. ``layer2`` is a regular ``torch.nn.Linear``.

    Args:
        in_features: Size of each flattened input sample.
        hidden_features: Width of the hidden (ReLU) layer.
        out_features: Number of output logits.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.layer1 = torch.nn.Linear(in_features, hidden_features)
        self.layer2 = torch.nn.Linear(hidden_features, out_features)
        # Number of in_features elements each kernel program loads per step.
        self.block_size = 32

    def forward(self, x):
        """Return logits of shape ``(batch, out_features)``.

        Inputs with more than two dimensions (e.g. batches of images) are
        flattened to ``(batch, -1)`` first.
        """
        if x.dim() > 2:
            # reshape (unlike view) also accepts non-contiguous inputs.
            x = x.reshape(x.size(0), -1)

        hidden = _TritonLinearFunction.apply(x, self.layer1.weight, self.block_size)
        # The kernel computes x @ W.T only; add the bias here so layer1 behaves
        # like the nn.Linear whose parameters it uses.
        hidden = F.relu(hidden + self.layer1.bias)
        return self.layer2(hidden)

# Cell 4 - Training Setup and Data Loading
# Seed before the model is built so weight init and loader shuffling are
# reproducible.
torch.manual_seed(42)

# ToTensor converts images to float tensors in [0, 1]; Normalize then applies
# the widely used MNIST mean/std (0.1307, 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Request the download for both splits so the test set does not depend on the
# train split having fetched its files first (already-present files are reused).
train_dataset = datasets.MNIST('/content/data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('/content/data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize model (flattened 28x28 images -> 256 hidden units -> 10 classes).
# .cuda() requires a GPU runtime; the Triton kernel needs one as well.
model = TritonNeuralNetwork(
    in_features=28*28,
    hidden_features=256,
    out_features=10
).cuda()

# Cell 5 - Training Function
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``, updating weights only if ``optimizer`` is given.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is averaged per sample
        and ``accuracy`` is a percentage.
    """
    total_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch = targets.size(0)
        # criterion returns the batch *mean*; weight it by the batch size so a
        # short final batch does not skew the epoch average.
        total_loss += loss.item() * batch
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += batch

    return total_loss / total, 100. * correct / total


def train_and_evaluate(model, train_loader, test_loader, epochs=10, learning_rate=0.01):
    """Train ``model`` with Adam + cross-entropy, evaluating after every epoch.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        train_loader: Iterable of ``(inputs, targets)`` batches for training.
        test_loader: Iterable of ``(inputs, targets)`` batches for evaluation.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        Dict of per-epoch lists under ``'train_loss'``, ``'train_acc'``,
        ``'test_loss'`` and ``'test_acc'``. Losses are per-sample means and
        accuracies are percentages. Train metrics are collected while the
        weights are still being updated within the epoch.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Send batches to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device

    metrics = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }

    for epoch in range(epochs):
        model.train()
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)

        model.eval()
        with torch.no_grad():
            test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)

        metrics['train_loss'].append(train_loss)
        metrics['train_acc'].append(train_acc)
        metrics['test_loss'].append(test_loss)
        metrics['test_acc'].append(test_acc)

        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

    return metrics

# Cell 6 - Run Training
# Hyper-parameters for this run; the per-epoch metrics returned here are
# what the plotting cell visualizes.
run_config = dict(epochs=10, learning_rate=1e-3)
metrics = train_and_evaluate(model, train_loader, test_loader, **run_config)

# Cell 7 - Plot Results
# Training prints 1-based epoch numbers; plot on the same axis (plotting a
# bare list would place the first epoch at x=0).
epoch_axis = range(1, len(metrics['train_loss']) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(epoch_axis, metrics['train_loss'], label='Train Loss')
loss_ax.plot(epoch_axis, metrics['test_loss'], label='Test Loss')
loss_ax.set_title('Loss over epochs')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

acc_ax.plot(epoch_axis, metrics['train_acc'], label='Train Accuracy')
acc_ax.plot(epoch_axis, metrics['test_acc'], label='Test Accuracy')
acc_ax.set_title('Accuracy over epochs')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.legend()

fig.tight_layout()
plt.show()