In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, hid_channels, num_res_blocks=2):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hid_channels = hid_channels
        self.num_res_blocks = num_res_blocks

        # Define layers
        self.in_conv = nn.Conv2d(in_channels, hid_channels, kernel_size=3, padding=1)
        self.down = nn.Sequential(
            nn.Conv2d(hid_channels, hid_channels * 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.middle = nn.Sequential(
            nn.Conv2d(hid_channels * 2, hid_channels * 2, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up = nn.Sequential(
            nn.ConvTranspose2d(hid_channels * 2, hid_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU()
        )
        self.out_conv = nn.Conv2d(hid_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.in_conv(x)
        x = self.down(x)
        x = self.middle(x)
        x = self.up(x)
        x = self.out_conv(x)
        return x


class QuantizeWeights(nn.Module):
    def __init__(self, bit_width):
        super(QuantizeWeights, self).__init__()
        self.bit_width = bit_width
        self.interval = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def initialize_interval(self, w):
        with torch.no_grad():
            max_val = w.max().item()
            min_val = w.min().item()
            self.interval.data = torch.tensor(
                (max_val - min_val) / (2 ** (self.bit_width - 1) - 1),
                device=self.interval.device
            )

    def forward(self, w):
        self.interval.data.clamp_(1e-5, float("inf"))
        max_val = 2 ** (self.bit_width - 1) - 1
        min_val = -2 ** (self.bit_width - 1)
        w_clamped = torch.clamp(w / self.interval, min_val, max_val)
        return torch.round(w_clamped) * self.interval


class QuantizeActivations(nn.Module):
    def __init__(self, bit_width):
        super(QuantizeActivations, self).__init__()
        self.bit_width = bit_width
        self.interval = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def initialize_interval(self, x):
        with torch.no_grad():
            max_val = x.max().item()
            self.interval.data = torch.tensor(
                max_val / (2 ** self.bit_width - 1),
                device=self.interval.device
            )

    def forward(self, x):
        self.interval.data.clamp_(1e-5, float("inf"))
        max_val = 2 ** self.bit_width - 1
        x_clamped = torch.clamp(x / self.interval, 0, max_val)
        return torch.round(x_clamped) * self.interval


class QuantizedUNet(UNet):
    def __init__(self, in_channels, out_channels, hid_channels, num_res_blocks=2, bit_width=8):
        super(QuantizedUNet, self).__init__(in_channels, out_channels, hid_channels, num_res_blocks)
        self.bit_width = bit_width
        self.weight_quantizer = QuantizeWeights(bit_width)
        self.activation_quantizer = QuantizeActivations(bit_width)
        self.initialized = False

    def initialize_quantization_intervals(self):
        """Initialize weight and activation intervals."""
        if self.initialized:
            return

        # Initialize weight quantization intervals
        for name, param in self.named_parameters():
            if 'weight' in name:
                self.weight_quantizer.initialize_interval(param)

        # Create dummy input on the same device as the model
        device = next(self.parameters()).device
        dummy_input = torch.randn(1, self.in_channels, 64, 64, device=device)

        # Perform a dummy forward pass to initialize activation quantizers
        with torch.no_grad():
            _ = self._dummy_forward(dummy_input)
        self.initialized = True

    def _dummy_forward(self, x):
        """A dummy forward pass for initialization."""
        x = self.in_conv(x)
        x = self.activation_quantizer(x)
        x = self.down(x)
        x = self.middle(x)
        x = self.up(x)
        x = self.out_conv(x)
        return x

    def forward(self, x):
        """Forward pass with quantization."""
        if not self.initialized:
            self.initialize_quantization_intervals()

        # Forward pass with quantized weights and activations
        x = self.in_conv(x)
        x = self.activation_quantizer(x)
        x = self.down(x)
        x = self.middle(x)
        x = self.activation_quantizer(x)
        x = self.up(x)
        x = self.out_conv(x)
        return x

    def load_from_unet(self, unet):
        """Load weights from a regular UNet."""
        self.load_state_dict(unet.state_dict(), strict=False)


In [None]:
!pip install pytorch-fid
from pytorch_fid import fid_score
import os
from PIL import Image


# Function to save images to disk
def save_images_to_disk(images, directory, prefix="image"):
    if not os.path.exists(directory):
        os.makedirs(directory)
    for i, img in enumerate(images):
        img = img.cpu().detach().numpy().transpose(1, 2, 0)  # Convert to numpy
        img = (img * 255).astype('uint8')  # Convert to [0, 255] for image saving
        image_path = os.path.join(directory, f"{prefix}_{i}.png")
        Image.fromarray(img).save(image_path)

# Calculate FID score
def calculate_fid_from_disk(real_images_dir, generated_images_dir):
    # Implement the FID calculation logic here
    # Placeholder function, assume it's already implemented
    return fid_score.calculate_fid_given_paths(
        [generated_images_dir, real_images_dir],
        batch_size=50,
        device="cuda",
        dims=2048
    )




In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm  # Import tqdm for the progress bar

# Define the UNet model (already provided in the question)
# Assuming that the UNet class is already defined.

# Step 1: Define a training function for UNet on CIFAR-10
def train_unet_on_cifar10(model, num_epochs=10, batch_size=64, lr=1e-3, device='cuda'):
    # Dataset and DataLoader
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    print(len(train_loader))

    # Loss function and optimizer
    loss_fn = nn.MSELoss()  # MSE loss
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Get the device (cuda or cpu)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Temporary directory for saving images
    real_images_dir = "./real_images"
    generated_images_dir = "./generated_images"

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Training loop
        with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for images, labels in pbar:
                images = images.to(device)  # Move images to the same device as the model

                # Forward pass
                outputs = model(images)
                loss = loss_fn(outputs, images)  # MSE loss between output and input image

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # Update tqdm progress bar with loss
                pbar.set_postfix(loss=running_loss / (pbar.n + 1))

        if epoch % 5 == 4:
            # Save images and calculate FID score at the end of each epoch
            save_images_to_disk(images, real_images_dir, prefix=f"real_epoch_{epoch+1}")
            save_images_to_disk(outputs, generated_images_dir, prefix=f"gen_epoch_{epoch+1}")

            # Calculate FID score for the epoch
            fid_score = calculate_fid_from_disk(real_images_dir, generated_images_dir)
            print(f"FID Score at Epoch {epoch+1}: {fid_score}")

        # Print the loss for every epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}")

# Initialize UNet model and train
unet = UNet(in_channels=3, out_channels=3, hid_channels=64).to('cuda')
train_unet_on_cifar10(unet, num_epochs=10)


Files already downloaded and verified
782


Epoch 1/10: 100%|██████████| 782/782 [00:36<00:00, 21.49batch/s, loss=0.00575]


Epoch [1/10], Loss: 0.005746270485600168


Epoch 2/10: 100%|██████████| 782/782 [00:29<00:00, 26.91batch/s, loss=0.000556]


Epoch [2/10], Loss: 0.0005560295879142656


Epoch 3/10: 100%|██████████| 782/782 [00:28<00:00, 27.61batch/s, loss=0.000384]


Epoch [3/10], Loss: 0.0003830612464377697


Epoch 4/10: 100%|██████████| 782/782 [00:29<00:00, 26.57batch/s, loss=0.000274]


Epoch [4/10], Loss: 0.0002733420763659598


Epoch 5/10: 100%|██████████| 782/782 [00:28<00:00, 27.67batch/s, loss=0.000248]




100%|██████████| 1/1 [00:00<00:00,  7.20it/s]




100%|██████████| 1/1 [00:00<00:00,  7.06it/s]


FID Score at Epoch 5: 44.45087886570661
Epoch [5/10], Loss: 0.00024755763807310187


Epoch 6/10: 100%|██████████| 782/782 [00:28<00:00, 27.92batch/s, loss=0.000205]


Epoch [6/10], Loss: 0.00020403270985384214


Epoch 7/10: 100%|██████████| 782/782 [00:28<00:00, 27.69batch/s, loss=0.000148]


Epoch [7/10], Loss: 0.00014801885790011906


Epoch 8/10: 100%|██████████| 782/782 [00:29<00:00, 26.93batch/s, loss=0.000129]


Epoch [8/10], Loss: 0.00012866361313728023


Epoch 9/10: 100%|██████████| 782/782 [00:29<00:00, 26.78batch/s, loss=0.000133]


Epoch [9/10], Loss: 0.00013255421255306622


Epoch 10/10: 100%|██████████| 782/782 [00:28<00:00, 27.09batch/s, loss=0.000121]




100%|██████████| 1/1 [00:00<00:00,  4.09it/s]




100%|██████████| 1/1 [00:00<00:00,  4.72it/s]


FID Score at Epoch 10: 36.14328189989487
Epoch [10/10], Loss: 0.0001212307139291585


In [None]:
# Step 2: Create a student model (QuantizedUNet)
# Training function for the quantized student model (QuantizedUNet)
def train_quantized_unet_from_baseline(unet, num_epochs=5, batch_size=64, lr=1e-3, device='cuda'):
    # Initialize student model (QuantizedUNet) with UNet weights
    student_model = QuantizedUNet(in_channels=3, out_channels=3, hid_channels=64, bit_width=8).to(device)
    student_model.load_from_unet(unet)

    # Dataset and DataLoader
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Loss function and optimizer
    loss_fn = nn.MSELoss()  # MSE loss
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    # Temporary directory for saving images
    real_images_dir = "./real_images"
    generated_images_dir = "./generated_images"

    # Train for 5 epochs
    for epoch in range(num_epochs):
        student_model.train()
        running_loss = 0.0

        # Training loop with tqdm progress bar
        with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for images, labels in pbar:
                images = images.to(device)  # Move images to the same device as the model

                # Forward pass
                outputs = student_model(images)
                loss = loss_fn(outputs, images)  # MSE loss

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # Update tqdm progress bar with loss
                pbar.set_postfix(loss=running_loss / (pbar.n + 1))

        # Save images and calculate FID score at the end of each epoch
        if (epoch + 1) % 3 == 0:  # Save and calculate FID every epoch
            save_images_to_disk(images, real_images_dir, prefix=f"real_epoch_{epoch+1}")
            save_images_to_disk(outputs, generated_images_dir, prefix=f"gen_epoch_{epoch+1}")

            # Calculate FID score for the epoch
            fid_score = calculate_fid_from_disk(real_images_dir, generated_images_dir)
            print(f"FID Score at Epoch {epoch+1}: {fid_score}")

        # Print the loss for every epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}")

    # Return the trained student model
    return student_model

student_model = train_quantized_unet_from_baseline(unet, num_epochs=3)



Files already downloaded and verified


Epoch 1/3: 100%|██████████| 782/782 [00:28<00:00, 27.80batch/s, loss=0.131]


Epoch [1/3], Loss: 0.1310089953300898


Epoch 2/3: 100%|██████████| 782/782 [00:28<00:00, 27.41batch/s, loss=0.154]


Epoch [2/3], Loss: 0.1536112810625597


Epoch 3/3: 100%|██████████| 782/782 [00:28<00:00, 27.60batch/s, loss=0.164]




100%|██████████| 1/1 [00:00<00:00,  2.94it/s]




100%|██████████| 1/1 [00:00<00:00,  3.44it/s]


FID Score at Epoch 3: 112.00976041198737
Epoch [3/3], Loss: 0.1635710603326483


In [None]:
# Define co-studying loss function
def co_studying_loss(x, teacher_logits, student_logits, temperature=1.0):
    import torch.nn.functional as F

    # Softened probabilities
    teacher_soft = F.softmax(teacher_logits / temperature, dim=1)
    student_soft = F.softmax(student_logits / temperature, dim=1)

    # MSE loss for both networks
    teacher_mse_loss = F.mse_loss(teacher_logits, x)
    student_mse_loss = F.mse_loss(student_logits, x)

    # KL divergence loss (softened)
    kl_loss_student = F.kl_div(student_soft.log(), teacher_soft, reduction='batchmean')
    kl_loss_teacher = F.kl_div(teacher_soft.log(), student_soft, reduction='batchmean')

    # Combine losses
    student_loss = student_mse_loss + (temperature ** 2) * kl_loss_student
    teacher_loss = teacher_mse_loss + (temperature ** 2) * kl_loss_teacher

    return teacher_loss, student_loss

# Step 3: Train a new teacher model with QuantizedUNet student model
def train_teacher_with_student(teacher_model, student_model, num_epochs=20, batch_size=64, lr=1e-3, device='cuda'):
    # Dataset and DataLoader
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Loss function and optimizers
    optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=lr)
    optimizer_student = optim.Adam(student_model.parameters(), lr=lr)

    # Directories for saving images
    real_images_dir = "./real_images"
    generated_images_dir_teacher = "./generated_images_teacher"
    generated_images_dir_student = "./generated_images_student"

    # Train for 20 epochs
    for epoch in range(num_epochs):
        teacher_model.train()
        student_model.train()
        running_teacher_loss = 0.0
        running_student_loss = 0.0

        # Training loop with tqdm progress bar
        with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for images, labels in pbar:
                images = images.to(device)  # Move images to the same device as the models

                # Forward pass for teacher and student
                teacher_outputs = teacher_model(images)
                student_outputs = student_model(images)

                # Compute co-studying loss
                teacher_loss, student_loss = co_studying_loss(images, teacher_outputs, student_outputs)

                # Backward pass and optimization
                optimizer_teacher.zero_grad()
                optimizer_student.zero_grad()

                # Teacher backward pass (retain graph to keep it for student backward pass)
                teacher_loss.backward(retain_graph=True)

                # Student backward pass
                student_loss.backward()

                optimizer_teacher.step()
                optimizer_student.step()

                running_teacher_loss += teacher_loss.item()
                running_student_loss += student_loss.item()

                # Update tqdm progress bar with loss
                pbar.set_postfix(
                    teacher_loss=running_teacher_loss / (pbar.n + 1),
                    student_loss=running_student_loss / (pbar.n + 1)
                )

        # Print the loss for every epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], Teacher Loss: {running_teacher_loss / len(train_loader)}, Student Loss: {running_student_loss / len(train_loader)}")

        # Save images and calculate FID score at the end of each epoch
        if (epoch + 1) % 7 == 0:  # Save and calculate FID every 5 epochs
            save_images_to_disk(images, real_images_dir, prefix=f"real_epoch_{epoch+1}")
            save_images_to_disk(teacher_outputs, generated_images_dir_teacher, prefix=f"gen_teacher_epoch_{epoch+1}")
            save_images_to_disk(student_outputs, generated_images_dir_student, prefix=f"gen_student_epoch_{epoch+1}")

            # Calculate FID score for teacher model
            fid_score_teacher = calculate_fid_from_disk(real_images_dir, generated_images_dir_teacher)
            print(f"FID Score for Teacher at Epoch {epoch+1}: {fid_score_teacher}")

            # Calculate FID score for student model
            fid_score_student = calculate_fid_from_disk(real_images_dir, generated_images_dir_student)
            print(f"FID Score for Student at Epoch {epoch+1}: {fid_score_student}")

teacher_model = UNet(in_channels=3, out_channels=3, hid_channels=64).to('cuda')
train_teacher_with_student(teacher_model, student_model, num_epochs=7, batch_size=64, lr=1e-3, device='cuda')


Files already downloaded and verified


Epoch 1/7: 100%|██████████| 782/782 [00:45<00:00, 17.29batch/s, student_loss=0.16, teacher_loss=0.0706]


Epoch [1/7], Teacher Loss: 0.07050962042292137, Student Loss: 0.1601168326176036


Epoch 2/7: 100%|██████████| 782/782 [00:44<00:00, 17.44batch/s, student_loss=0.132, teacher_loss=0.0329]


Epoch [2/7], Teacher Loss: 0.0328873094800107, Student Loss: 0.13147577581465092


Epoch 3/7: 100%|██████████| 782/782 [00:45<00:00, 17.19batch/s, student_loss=0.123, teacher_loss=0.0281]


Epoch [3/7], Teacher Loss: 0.028065757728312785, Student Loss: 0.12249921176515882


Epoch 4/7: 100%|██████████| 782/782 [00:45<00:00, 17.11batch/s, student_loss=0.12, teacher_loss=0.0272]


Epoch [4/7], Teacher Loss: 0.02717136810688526, Student Loss: 0.12013535706512153


Epoch 5/7: 100%|██████████| 782/782 [00:44<00:00, 17.39batch/s, student_loss=0.12, teacher_loss=0.025]


Epoch [5/7], Teacher Loss: 0.02495403000441811, Student Loss: 0.11984985988691944


Epoch 6/7: 100%|██████████| 782/782 [00:44<00:00, 17.44batch/s, student_loss=0.121, teacher_loss=0.0247]


Epoch [6/7], Teacher Loss: 0.02462351443531835, Student Loss: 0.12103266314701046


Epoch 7/7: 100%|██████████| 782/782 [00:44<00:00, 17.39batch/s, student_loss=0.123, teacher_loss=0.0242]


Epoch [7/7], Teacher Loss: 0.02419323412120304, Student Loss: 0.12291750722490918


100%|██████████| 1/1 [00:00<00:00,  5.85it/s]
100%|██████████| 2/2 [00:00<00:00,  4.87it/s]


FID Score for Teacher at Epoch 7: 295.2592911249533


100%|██████████| 1/1 [00:00<00:00,  6.26it/s]
100%|██████████| 2/2 [00:00<00:00,  5.16it/s]


FID Score for Student at Epoch 7: 329.0478384838251
