# Deploying LLMs: Strategies and Considerations

## Outline:
- Techniques for efficient deployment of LLMs in production environments.
- **Hands-on Lab:** Quantization, pruning, and distillation techniques.

---
### Overview of Deployment Challenges
- High computational requirements
- Latency issues
- Scalability concerns
- Security and compliance
---
### Model Selection
- Choose an appropriate model (e.g., GPT-3, GPT-4, BERT) based on your needs
- Consider the trade-offs between model complexity and performance

### Resource Allocation
- Estimate the computational resources required for training and inference
- Plan for infrastructure costs and scalability
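
As a rough illustration of the estimate above, the memory needed just to hold a model's weights is approximately the parameter count times the bytes per parameter. The sketch below assumes a hypothetical 7-billion-parameter model (not one named in this section) and a helper called `weight_memory_gb` invented for this example. It counts weights only; activations, attention caches, and any optimizer state come on top.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Return the weight-only memory footprint in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Hypothetical model size, chosen only for illustration.
num_params = 7_000_000_000

for dtype in BYTES_PER_PARAM:
    print(f"{dtype}: {weight_memory_gb(num_params, dtype):.1f} GB")
```

Halving the precision halves the weight footprint, which is why the quantization techniques later in this section matter for infrastructure planning.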
---
## Optimization Techniques
- **Quantization:** Reduces model size and increases speed by lowering precision.
- **Pruning:** Removes unimportant connections to create a sparser, faster model.
- **Distillation:** Trains a smaller model to replicate a larger model's behavior, capturing its knowledge efficiently.


---
## Deployment Strategies

### Batch Processing vs. Real-Time Inference
- **Batch Processing:** Suitable for non-time-sensitive tasks, efficient resource utilization
- **Real-Time Inference:** Essential for applications requiring immediate responses
---



## **Optimization Techniques for LLMs**
---
### Quantization


- Reducing the precision of model weights to decrease memory usage and increase inference speed
- Common techniques: 8-bit integer (INT8) quantization and reduced-precision floating point (FP16/BF16), the formats also used in mixed-precision training
- Types of Quantization

  - **Post-Training Quantization:** Applied after training the model, converting weights from floating-point precision (e.g., FP32) to lower precision (e.g., INT8).
  - **Quantization-Aware Training:** The model is trained with quantization in mind, simulating low-precision calculations during training to better adjust the weights.

- [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization)
- [Post-Training Quantization](https://www.tensorflow.org/model_optimization/guide/quantization/post_training)
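
To make the FP32-to-INT8 conversion concrete, here is a minimal sketch of symmetric per-tensor post-training quantization, written with plain PyTorch tensor operations. The helper names `quantize_int8` and `dequantize` and the random stand-in weight matrix are assumptions for illustration. Production toolkits add calibration data, per-channel scales, and dedicated INT8 kernels.

```python
import torch


def quantize_int8(w: torch.Tensor):
    """Symmetric per-tensor quantization: map [-max|w|, +max|w|] onto [-127, 127]."""
    scale = w.abs().max() / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an FP32 approximation of the original weights."""
    return q.to(torch.float32) * scale


torch.manual_seed(0)
w = torch.randn(512, 784) * 0.05  # stand-in for a trained FP32 weight matrix

q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

fp32_bytes = w.numel() * w.element_size()
int8_bytes = q.numel() * q.element_size()
max_err = (w - w_hat).abs().max().item()

print(f"FP32 size: {fp32_bytes} bytes, INT8 size: {int8_bytes} bytes")
print(f"scale: {scale.item():.6f}, max abs error: {max_err:.6f}")
```

Storage drops by a factor of four, and the rounding error per weight is bounded by half the scale. Whether that error is acceptable depends on the model and task, which is the gap quantization-aware training tries to close.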



### Pruning

- Removing less important weights, neurons, or layers to reduce model size and complexity, ideally improving inference speed without significantly compromising accuracy.
- Methods: Magnitude pruning, structured pruning
- Types of Pruning

  - **Magnitude-Based Pruning:** Removes weights with the smallest magnitudes.
  - **Structured Pruning:** Removes entire neurons, channels, or layers.


- [TensorFlow Model Optimization: Pruning](https://www.tensorflow.org/model_optimization/guide/pruning)
- [To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression](https://arxiv.org/abs/1710.01878)
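
The two pruning types listed above can be sketched with PyTorch's `torch.nn.utils.prune` module. The tiny layer sizes and pruning amounts below are arbitrary choices for illustration, not recommended settings.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Magnitude-based (unstructured) pruning:
# zero the 50% of individual weights with the smallest absolute value.
layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)
unstructured_sparsity = (layer.weight == 0).float().mean().item()

# Structured pruning:
# zero 25% of the output neurons (whole rows), ranked by their L2 norm.
layer2 = nn.Linear(8, 4)
prune.ln_structured(layer2, name="weight", amount=0.25, n=2, dim=0)
zero_rows = int((layer2.weight.abs().sum(dim=1) == 0).sum().item())

# Make the pruning permanent by folding each mask into its weight tensor.
prune.remove(layer, "weight")
prune.remove(layer2, "weight")

print(f"Unstructured sparsity: {unstructured_sparsity:.2f}")
print(f"Fully zeroed output neurons: {zero_rows} of {layer2.out_features}")
```

These utilities only set weights to zero; the tensors keep their dense shape. Turning that sparsity into real speed or size gains usually needs sparse kernels or, for structured pruning, physically removing the zeroed rows.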
---

### Knowledge Distillation

- Training a smaller model (student) to replicate the performance of a larger pre-trained model (teacher)
- Benefits: Reduced computational requirements, faster inference
- The student model learns to mimic the teacher's output, effectively capturing the knowledge in a more compact form.

- [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
- [DistilBERT, a distilled version of BERT](https://arxiv.org/abs/1910.01108)
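
One common way to write the distillation objective, following the Hinton et al. paper linked above, combines a hard-label term with a soft-target term. The notation is an assumption introduced here: $z_s$ and $z_t$ are the student and teacher logits, $y$ the true labels, $T$ the temperature, and $\alpha$ the weight on the hard-label term.

```latex
\mathcal{L}
  = \alpha \,\mathrm{CE}\!\left(y,\ \mathrm{softmax}(z_s)\right)
  + (1-\alpha)\, T^{2}\,
    \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T)\ \big\|\ \mathrm{softmax}(z_s / T)\right)
```

A higher $T$ softens both distributions, so the student also learns from the teacher's relative probabilities for the wrong classes. The $T^{2}$ factor compensates for the soft-target gradients shrinking roughly as $1/T^{2}$, which keeps the two terms on a comparable scale when $T$ changes.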
---

# **Lab:** Model Distillation

- Perform knowledge distillation on a simple neural network for the MNIST dataset
- How small can the student model be before performance starts to deteriorate?




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy


In [None]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNN()


In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)


In [None]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)} ({100.*batch_idx/len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 6):
    train(model, device, train_loader, optimizer, criterion, epoch)




In [None]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
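            # Sum per-sample losses so that dividing by the dataset size below gives a true average
            test_loss += nn.functional.cross_entropy(output, target, reduction='sum').item()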
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    return accuracy

print("Original Model Performance:")
original_accuracy = test(model, device, test_loader)


Original Model Performance:

Test set: Average loss: 0.0001, Accuracy: 9777/10000 (97.77%)



In [None]:
list(model.parameters())[0]

Parameter containing:
tensor([[-0.0058,  0.0423,  0.0571,  ...,  0.0131,  0.0508,  0.0180],
        [ 0.0229,  0.0404, -0.0243,  ..., -0.0061, -0.0089,  0.0215],
        [ 0.0150,  0.0047,  0.0189,  ...,  0.0017, -0.0199,  0.0299],
        ...,
        [ 0.0040, -0.0094, -0.0156,  ...,  0.0445,  0.0523,  0.0082],
        [ 0.0269,  0.0020,  0.0441,  ...,  0.0287,  0.0293,  0.0522],
        [ 0.0360,  0.0171,  0.0128,  ..., -0.0170,  0.0104,  0.0284]],
       requires_grad=True)

In [None]:
# Distillation: train a smaller student to match the teacher's softened outputs

class SmallNN(nn.Module):
    def __init__(self):
        super(SmallNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

student_model = SmallNN().to(device)

def distillation_loss(student_outputs, teacher_outputs, labels, T, alpha):
    # Soft targets: teacher probabilities softened by temperature T
    soft_targets = nn.functional.softmax(teacher_outputs / T, dim=1)
    # Hard-label loss: standard cross-entropy against the true labels
    hard_loss = nn.functional.cross_entropy(student_outputs, labels)
    # Soft-target loss: KL(teacher || student) at temperature T, scaled by T^2
    soft_loss = nn.functional.kl_div(nn.functional.log_softmax(student_outputs / T, dim=1), soft_targets, reduction='batchmean') * (T * T)
    return alpha * hard_loss + (1 - alpha) * soft_loss

teacher_model = model
teacher_model.eval()
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)

for epoch in range(1, 6):
    student_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        student_optimizer.zero_grad()
        student_output = student_model(data)
        # The teacher is frozen, so skip building its autograd graph
        with torch.no_grad():
            teacher_output = teacher_model(data)
        loss = distillation_loss(student_output, teacher_output, target, T=2.0, alpha=0.5)
        loss.backward()
        student_optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)} ({100.*batch_idx/len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

print("Distilled Student Model Performance:")
distilled_accuracy = test(student_model, device, test_loader)


Distilled Student Model Performance:

Test set: Average loss: 0.0001, Accuracy: 9754/10000 (97.54%)



In [None]:
list(teacher_model.parameters())[0].shape

torch.Size([512, 784])

In [None]:
list(student_model.parameters())[0].shape

torch.Size([128, 784])

In [None]:
list(teacher_model.parameters())[1].shape

torch.Size([512])

In [None]:
list(student_model.parameters())[1].shape

torch.Size([128])
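
The shapes above can be turned into total parameter counts to see how much smaller the student is. The sketch below uses a hypothetical helper, `mlp_param_count`, that mirrors the two-layer architecture of `SimpleNN` and `SmallNN`. The hidden sizes below 128 are candidate values to try for the lab question, not results measured here.

```python
def mlp_param_count(hidden: int, in_features: int = 28 * 28, out_features: int = 10) -> int:
    """Parameters of a two-layer MLP: (weights + biases) for fc1 and fc2."""
    fc1 = in_features * hidden + hidden
    fc2 = hidden * out_features + out_features
    return fc1 + fc2


teacher_params = mlp_param_count(512)  # 407050
student_params = mlp_param_count(128)  # 101770
ratio = teacher_params / student_params

print(f"Teacher: {teacher_params} parameters")
print(f"Student: {student_params} parameters")
print(f"Compression ratio: {ratio:.2f}x")

# Candidate hidden sizes for exploring how far the student can shrink.
for hidden in (64, 32, 16, 8):
    print(f"hidden={hidden}: {mlp_param_count(hidden)} parameters")
```

With the recorded accuracies of 97.77% for the teacher and 97.54% for the student, the 128-unit student gives up about 0.2 percentage points for roughly a fourfold reduction in parameters.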