In [None]:
import tensorflow as tf

device_name = tf.test.gpu_device_name()

if "GPU" not in device_name:
    print("GPU device not found")
    
print('Found GPU at: {}'.format(device_name))

print("GPU", "available (YESS!!!!)" if tf.config.list_physical_devices("GPU") else "not available :(")

# Imports 

In [4]:
import os
import json

import matplotlib
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
import math
from prettytable import PrettyTable

# VIT_CIFAR100.py 

In [29]:
# Literally the only thing we will change is now there will be 
# two layers in FFN. 


class PatchEmbeddings(nn.Module):
    """
    parameters: Image_size, patch_size, num_channels, embed_dim
    """

    def __init__(self, config):
        super().__init__()
        self.image_size = config["image_size"]
        self.patch_size = config["patch_size"]
        self.num_channels = config["num_channels"]
        self.embed = config["embed_dim"]

        # no. of patches from the image size and patch size
        self.num_patches = (self.image_size // self.patch_size) ** 2 # from (HW)/P^2

        # Create a projection layer to convert the image into patches
        # The layer projects each patch into a vector of size embed_dim
        self.projection = nn.Conv2d(self.num_channels, self.embed, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        # (batch_size, num_channels, image_size, image_size) -> (batch_size, num_patches, embed_dim)
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class Embeddings(nn.Module):
    """
    Combine the patch embeddings with the class token and position embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)

        # Creating a learnable [CLS] token


        # the nn.parameter creates a learnable tensor updated by backpropogation
        # we create a random vector from normal distribution along (1, 1, 48)

        self.cls_token = nn.Parameter(torch.randn(1, 1, config["embed_dim"]))

        # Creating learnable position embeddings for the [CLS] token and patches
        # Add 1 to the sequence length for the [CLS] token

        self.position_embeddings = \
            nn.Parameter(torch.randn(1, self.patch_embeddings.num_patches + 1, config["embed_dim"]))

        self.dropout = nn.Dropout(config["dropout_val"])

    def forward(self, x):

        x = self.patch_embeddings(x)
        batch_size, _, _ = x.size()

        # Expand the [CLS] token to the batch size
        # (1, 1, embed_dim) -> (batch_size, 1, embed_dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # Concatenate the [CLS] token to the beginning of the input sequence
        # This results in a sequence length of (num_patches + 1)

        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.position_embeddings
        x = self.dropout(x)
        return x
    
class AttentionHead(nn.Module):
    """
    A single attention head.
    This module is used in the MultiHeadAttention module.
    """
    def __init__(self, embed_dim, attention_head_dim, dropout, bias=True):
        super().__init__()
        self.embed = embed_dim
        self.attention_head_dim = attention_head_dim
        # Create the query, key, and value projection layers
        self.query = nn.Linear(embed_dim, attention_head_dim, bias=bias)
        self.key = nn.Linear(embed_dim, attention_head_dim, bias=bias)
        self.value = nn.Linear(embed_dim, attention_head_dim, bias=bias)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Project the input into query, key, and value
        # The same input is used to generate the query, key, and value,
        # so it's usually called self-attention.

        # (batch_size, sequence_length, embed_dim)
        # transforms to
        # -> (batch_size, sequence_length, attention_head_dim)


        query = self.query(x)
        key = self.key(x)
        value = self.value(x)


        # Calculate the attention scores
        # softmax(Q*K.T/sqrt(head_size))*V

        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_dim)
        #we will now return attention_scores and apply MSE on those. 
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Calculate the attention output
        attention_output = torch.matmul(attention_probs, value)

        return (attention_output, attention_probs)


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    This module is used in the TransformerEncoder module.
    """

    def __init__(self, config):
        super().__init__()
        self.embed = config["embed_dim"]
        self.num_attention_heads = config["num_attention_heads"]
        # The attention head size is the hidden size divided by the number of attention heads
        self.attention_head_dim = self.embed // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_dim

        # Whether or not to use bias in the query, key, and value projection layers
        self.qkv_bias = config["qkv_bias"]

        # Create a list of attention heads

        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.embed,
                self.attention_head_dim,
                config["attention_probs_dropout_prob"],
                self.qkv_bias
            )
            self.heads.append(head)

        # Create a linear layer to project the attention output back to the hidden size
        # In most cases, all_head_size and embed_dim are the same
        self.output_projection = nn.Linear(self.all_head_size, self.embed)
        self.output_dropout = nn.Dropout(config["dropout_val"])

    def forward(self, x, output_attentions=False):
        # Calculate the attention output for each attention head
        attention_outputs = [head(x) for head in self.heads]
        # Concatenate the attention outputs from each attention head
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        # Project the concatenated attention output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Return the attention output and the attention probabilities (optional)
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)

class MLP(nn.Module):
    """
    A multi-layer perceptron module.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_1 = nn.Linear(config["embed_dim"], config["hidden_dim"])
        self.activation = nn.GELU()
        self.dense_intermediate = nn.Linear(config["hidden_dim"], config["hidden_dim"]//2)
        self.dense_2 = nn.Linear(config["hidden_dim"]//2, config["embed_dim"])
        self.dropout = nn.Dropout(config["dropout_val"])

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_intermediate(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x
    
class Block(nn.Module):
    """
    A single transformer block.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.layernorm_1 = nn.LayerNorm(config["embed_dim"])
        self.mlp = MLP(config)
        self.layernorm_2 = nn.LayerNorm(config["embed_dim"])

    def forward(self, x, output_attentions=False):
        # Self-attention
        attention_output, attention_probs = \
            self.attention(self.layernorm_1(x), output_attentions=output_attentions)
        # Skip connection
        x = x + attention_output
        # Feed-forward network
        mlp_output = self.mlp(self.layernorm_2(x))
        # Skip connection
        x = x + mlp_output
        # Return the transformer block's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attention_probs)
        
class Encoder(nn.Module):
    """
    The transformer encoder module.
    """

    def __init__(self, config):
        super().__init__()
        # Create a list of transformer blocks
        self.blocks = nn.ModuleList([])
        for _ in range(config["num_hidden_layers"]):
            block = Block(config)
            self.blocks.append(block)

    def forward(self, x, output_attentions=False):
        # Calculate the transformer block's output for each block
        all_attentions = []
        for block in self.blocks:
            x, attention_probs = block(x, output_attentions=output_attentions)
            if output_attentions:
                all_attentions.append(attention_probs)
        # Return the encoder's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, all_attentions)
        
class ViTForClassification(nn.Module):
    """
    The ViT model for classification.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.embed = config["embed_dim"]
        self.num_classes = config["num_classes"]
        # Create the embedding module
        self.embedding = Embeddings(config)
        # Create the transformer encoder module
        self.encoder = Encoder(config)
        # Create a linear layer to project the encoder's output to the number of classes
        self.classifier = nn.Linear(self.embed, self.num_classes)
        # Initialize the weights
        self.apply(self._init_weights)

    def forward(self, x, output_attentions=False):
        # Calculate the embedding output
        embedding_output = self.embedding(x)
        # Calculate the encoder's output
        encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
        # Calculate the logits, take the [CLS] token's output as features for classification
        logits = self.classifier(encoder_output[:, 0])
        # Return the logits and the attention probabilities (optional)
        if not output_attentions:
            return (logits, None)
        else:
            return (logits, all_attentions)
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config["initializer_range"])
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config["initializer_range"],
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config["initializer_range"],
            ).to(module.cls_token.dtype)


# Utilities 

In [30]:
def save_experiment(experiment_name, config, model, train_losses, test_losses, accuracies, base_dir="experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    os.makedirs(outdir, exist_ok=True)

    # Save the config
    configfile = os.path.join(outdir, 'config.json')
    with open(configfile, 'w') as f:
        json.dump(config, f, sort_keys=True, indent=4)

    # Save the metrics
    jsonfile = os.path.join(outdir, 'metrics.json')
    with open(jsonfile, 'w') as f:
        data = {
            'train_losses': train_losses,
            'test_losses': test_losses,
            'accuracies': accuracies,
        }
        json.dump(data, f, sort_keys=True, indent=4)

    # Save the model
    save_checkpoint(experiment_name, model, "final", base_dir=base_dir)


def save_checkpoint(experiment_name, model, epoch, base_dir="experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    os.makedirs(outdir, exist_ok=True)
    cpfile = os.path.join(outdir, f'model_{epoch}.pt')
    torch.save(model.state_dict(), cpfile)


def load_experiment(experiment_name, checkpoint_name="model_final.pt", base_dir="experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    # Load the config
    configfile = os.path.join(outdir, 'config.json')
    with open(configfile, 'r') as f:
        config = json.load(f)
    # Load the metrics
    jsonfile = os.path.join(outdir, 'metrics.json')
    with open(jsonfile, 'r') as f:
        data = json.load(f)
    train_losses = data['train_losses']
    test_losses = data['test_losses']
    accuracies = data['accuracies']
    # Load the model
    model = ViTForClassfication(config)
    cpfile = os.path.join(outdir, checkpoint_name)
    model.load_state_dict(torch.load(cpfile))
    return config, model, train_losses, test_losses, accuracies


def visualize_images():
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True)
    classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # Pick 30 samples randomly
    indices = torch.randperm(len(trainset))[:30]
    images = [np.asarray(trainset[i][0]) for i in indices]
    labels = [trainset[i][1] for i in indices]
    # Visualize the images using matplotlib
    fig = plt.figure(figsize=(10, 10))
    for i in range(30):
        ax = fig.add_subplot(6, 5, i+1, xticks=[], yticks=[])
        ax.imshow(images[i])
        ax.set_title(classes[labels[i]])


@torch.no_grad()
def visualize_attention(model, output=None, device="cuda"):
    """
    Visualize the attention maps of the first 4 images.
    """
    model.eval()
    # Load random images
    num_images = 30
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # Pick 30 samples randomly
    indices = torch.randperm(len(testset))[:num_images]
    raw_images = [np.asarray(testset[i][0]) for i in indices]
    labels = [testset[i][1] for i in indices]
    # Convert the images to tensors
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    images = torch.stack([test_transform(image) for image in raw_images])
    # Move the images to the device
    images = images.to(device)
    model = model.to(device)
    # Get the attention maps from the last block
    logits, attention_maps = model(images, output_attentions=True)
    # Get the predictions
    predictions = torch.argmax(logits, dim=1)
    # Concatenate the attention maps from all blocks
    attention_maps = torch.cat(attention_maps, dim=1)
    # select only the attention maps of the CLS token
    attention_maps = attention_maps[:, :, 0, 1:]
    # Then average the attention maps of the CLS token over all the heads
    attention_maps = attention_maps.mean(dim=1)
    # Reshape the attention maps to a square
    num_patches = attention_maps.size(-1)
    size = int(math.sqrt(num_patches))
    attention_maps = attention_maps.view(-1, size, size)
    # Resize the map to the size of the image
    attention_maps = attention_maps.unsqueeze(1)
    attention_maps = F.interpolate(attention_maps, size=(32, 32), mode='bilinear', align_corners=False)
    attention_maps = attention_maps.squeeze(1)
    # Plot the images and the attention maps
    fig = plt.figure(figsize=(20, 10))
    mask = np.concatenate([np.ones((32, 32)), np.zeros((32, 32))], axis=1)
    for i in range(num_images):
        ax = fig.add_subplot(6, 5, i+1, xticks=[], yticks=[])
        img = np.concatenate((raw_images[i], raw_images[i]), axis=1)
        ax.imshow(img)
        # Mask out the attention map of the left image
        extended_attention_map = np.concatenate((np.zeros((32, 32)), attention_maps[i].cpu()), axis=1)
        extended_attention_map = np.ma.masked_where(mask==1, extended_attention_map)
        ax.imshow(extended_attention_map, alpha=0.5, cmap='jet')
        # Show the ground truth and the prediction
        gt = classes[labels[i]]
        pred = classes[predictions[i]]
        ax.set_title(f"gt: {gt} / pred: {pred}", color=("green" if gt==pred else "red"))
    if output is not None:
        plt.savefig(output)
    plt.show()





class Modified_Trainer:
    def __init__(self, model, optimizer, loss_fn, exp_name, device,config, base_dir="experiments"):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device
        self.base_dir = base_dir
        self.config = config 

    def train(self, trainloader, testloader, epochs,save_model_every_n_epochs, output_attentions=False):
        train_losses, test_losses, accuracies = [], [], []
        all_epoch_attentions = []  # Collect attention from all epochs

        for i in range(epochs):
            if output_attentions:
                train_loss, epoch_attentions = self.train_epoch(trainloader, output_attentions=True)
                all_epoch_attentions.append(epoch_attentions)  # Store attention data
            else:
                train_loss = self.train_epoch(trainloader)

            accuracy, test_loss = self.evaluate(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")

            if save_model_every_n_epochs > 0 and (i+1) % save_model_every_n_epochs == 0:
                save_checkpoint(self.exp_name, self.model, f"epoch_{i+1}", base_dir=self.base_dir)
                print(f'\tSaving checkpoint at epoch {i+1}')

        save_experiment(self.exp_name, self.config , self.model, train_losses, test_losses, accuracies, base_dir=self.base_dir)
        print(f'Final model and experiment details saved under {self.exp_name}')

        return all_epoch_attentions if output_attentions else None

    def train_epoch(self, trainloader, output_attentions=False):
        self.model.train()
        total_loss = 0
        block_attentions = []  # to store attention probabilities if needed

        for _, (images, labels) in enumerate(trainloader):
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()

            # Check if attention probabilities are needed
            if output_attentions:
                logits, attention_probs = self.model(images, output_attentions=True)
                block_attentions.append(attention_probs)
            else:
                logits, _ = self.model(images)

            loss = self.loss_fn(logits, labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item() * len(images)

        average_loss = total_loss / len(trainloader.dataset)

        if output_attentions:
            return average_loss, block_attentions  # Optionally return attention data
        return average_loss

    def evaluate(self, testloader):
        self.model.eval()
        total_loss, correct = 0, 0
        for _, (images, labels) in enumerate(testloader):
            images, labels = images.to(self.device), labels.to(self.device)
            logits, _ = self.model(images)
            loss = self.loss_fn(logits, labels)
            total_loss += loss.item() * len(images)
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss

# Modified Prepare Data

In [31]:
def modified_prepare_data(batch_size=4, num_workers=2, train_sample_size=None, test_sample_size=None):
    # TRAINING TRANSFORMATION
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32),antialias=True),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    trainset = torchvision.datasets.CIFAR100('./', train=True, download=True, transform =train_transform)



    if train_sample_size is not None:
        # Randomly sample a subset of the training set
        indices = torch.randperm(len(trainset))[:train_sample_size]
        trainset = torch.utils.data.Subset(trainset, indices)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)

    # TEST TRANSFORMATIONS
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32),antialias=True),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    testset = torchvision.datasets.CIFAR100('./', train=False, download=True, transform =test_transform)



    if test_sample_size is not None:
        # Randomly sample a subset of the test set
        indices = torch.randperm(len(testset))[:test_sample_size]
        testset = torch.utils.data.Subset(testset, indices)

    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    # Update classes for CIFAR-100
    classes = tuple(f"class_{i+1}" for i in range(100))  # Placeholder for actual class names, update as needed

    return trainloader, testloader, classes

# TS-Student

In [40]:
class TS_Trainer:
    def __init__(self, teacher_model, student_model, optimizer_teacher, optimizer_student, loss_fn, exp_name, device,alpha, base_dir="experiments"):
        self.teacher_model = teacher_model.to(device)
        self.student_model = student_model.to(device)

        self.optimizer_teacher = optimizer_teacher
        self.optimizer_student = optimizer_student
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device
        self.base_dir = base_dir
        self.alpha = alpha


    def train_epoch_teacher(self, trainloader):
        self.teacher_model.train()
        total_loss = 0
        for _, (images, labels) in enumerate(trainloader):
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer_teacher.zero_grad()
            logits = self.teacher_model(images)[0]
            loss = self.loss_fn(logits, labels)
            loss.backward()
            self.optimizer_teacher.step()
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)

    def train_epoch_student(self,trainloader):
        self.student_model.train()
        total_loss = 0
        total_attention_loss = 0 
        total_classification_loss = 0        
        
        for _, (images, labels) in enumerate(trainloader):
          images, labels = images.to(self.device), labels.to(self.device)
          self.optimizer_student.zero_grad()
          # obtaining logits and attention
          student_logits, student_attention = self.student_model(images, output_attentions=True)
          with torch.no_grad(): # ENSURing THAT WE DON'T TRAIN THE TEACHER.
            _, teacher_attention = self.teacher_model(images, output_attentions=True) # we only need attention

          # Compute losses
          classification_loss = self.loss_fn(student_logits, labels)


          # Here we have to be careful. The actual student attention and
          # teacher attention are found in the first index [0].
          #They should be of dimensions [64 x 4 x 17 x 17]
          #print(student_attention[0].shape)
          #print(teacher_attention[0].shape)

          student_attention_block = student_attention[0]
          teacher_attention_block = teacher_attention[0]

          #attention_loss = nn.CrossEntropyLoss()(student_attention_block, teacher_attention_block)

          #############################################################################################
          #Instead of attention_loss using CrossEntropy, lets try computing Kullback-Leibler divergence

          log_probs = torch.log(student_attention_block + 1e-10)  # Convert the student's attention block to log probabilities
          
          attention_loss = F.kl_div(log_probs, teacher_attention_block, reduction='batchmean')
          weighted_classification = self.alpha * classification_loss
          weighted_attention = (1 - self.alpha) * attention_loss

          loss = weighted_classification + weighted_attention


          # Backward and optimize
          loss.backward()
          self.optimizer_student.step()
          total_loss += loss.item() * images.size(0)
          total_classification_loss += weighted_classification.item()* images.size(0)
          total_attention_loss += weighted_attention.item()* images.size(0)

            
        total_loss = total_loss / len(trainloader.dataset)
        total_classification_loss = total_classification_loss / len(trainloader.dataset)
        total_attention_loss = total_attention_loss/len(trainloader.dataset)


        return total_loss, total_classification_loss, total_attention_loss


    def evaluate_teacher(self, testloader):
        self.teacher_model.eval()
        total_loss, correct = 0, 0
        for _, (images, labels) in enumerate(testloader):
            images, labels = images.to(self.device), labels.to(self.device)
            logits, _ = self.teacher_model(images)
            loss = self.loss_fn(logits, labels)
            total_loss += loss.item() * len(images)
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss

    def evaluate_student(self, testloader):
        self.student_model.eval()
        total_loss, correct = 0, 0
        for _, (images, labels) in enumerate(testloader):
            images, labels = images.to(self.device), labels.to(self.device)
            logits, _ = self.student_model(images)
            # WE COMPUTE ONLY CLASSIFICATION LOSS WHEN EVALUATING STUDENT.
            # NO DISTILLATION LOSS COMPUTED HERE.
            loss = self.loss_fn(logits, labels)
            total_loss += loss.item() * len(images)
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss




    def train_teacher(self, config, trainloader,testloader, epochs,save_model_every_n_epochs=0):

        train_losses, test_losses, accuracies = [], [], []
        self.teacher_model.train()
        for i in range(epochs):
          train_loss = self.train_epoch_teacher(trainloader)
          accuracy, test_loss = self.evaluate_teacher(testloader)
          train_losses.append(train_loss)
          test_losses.append(test_loss)
          accuracies.append(accuracy)
          print(f"Epoch: {i+1}, Teacher Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")

          # Save checkpoint if required
          if save_model_every_n_epochs > 0 and (i+1) % save_model_every_n_epochs == 0:
              save_checkpoint(self.exp_name, self.teacher_model, f"epoch_{i+1}", base_dir=self.base_dir)
              print(f'\tSaving checkpoint at epoch {i+1}')

        # Save the final model and experiment details at the end of training
        save_experiment(self.exp_name, config, self.teacher_model, train_losses, test_losses, accuracies, base_dir=self.base_dir)
        print(f'teacher_final and experiment details saved under {self.exp_name}')

# THIS IS THE ACTUAL NOVELTY. WE EMPLEMENT THIS FROM SCRATCH.
    def train_student(self, config, trainloader,testloader, epochs,save_model_every_n_epochs):
        train_losses, test_losses, attention_losses, accuracies = [], [], [],[]
        self.student_model.train()
        for i in range(epochs):
          train_loss, classification_loss, attention_loss = self.train_epoch_student(trainloader)
          accuracy, test_loss = self.evaluate_student(testloader)
          train_losses.append(train_loss)
          test_losses.append(test_loss)
          attention_losses.append(attention_loss)
          accuracies.append(accuracy)
          print(f"Epoch: {i+1}, Student Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f},Attention Loss: {attention_loss:.4f}, Accuracy: {accuracy:.4f}")

          # Save checkpoint if required
          if save_model_every_n_epochs > 0 and (i+1) % save_model_every_n_epochs == 0:
              save_checkpoint(self.exp_name, self.student_model, f"epoch_{i+1}", base_dir=self.base_dir)
              print(f'\tSaving checkpoint at epoch {i+1}')

        # Save the final model and experiment details at the end of training
        save_experiment(self.exp_name, config, self.student_model, train_losses, test_losses, accuracies, base_dir=self.base_dir)
        print(f'student_final and experiment details saved under {self.exp_name}')

# Training a Normal Student

In [42]:
config = {
    "patch_size": 8,  # Input image size: 32x32 -> 8x8 patches
    "embed_dim": 64,
    "num_hidden_layers": 3,
    "num_attention_heads": 4,
    "hidden_dim": 64 * 8, ## 32 * embed_dim -> i think this helped increase accuracy on CIFAR100
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 100, # num_classes of CIFAR100
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}

# These are not hard constraints, but are used to prevent misconfigurations
assert config["embed_dim"] % config["num_attention_heads"] == 0
#assert config['hidden_dim'] == 4 * config['embed_dim']
assert config['image_size'] % config['patch_size'] == 0

model = ViTForClassification(config)

# Configuration parameters
exp_name = "ViT_CIFAR100"
batch_size = 64
epochs = 100
lr = 0.001
device = "cuda" if torch.cuda.is_available() else "cpu"
save_model_every_n_epochs = 5

trainloader, testloader, _ = modified_prepare_data(batch_size)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.001)
loss_fn = nn.CrossEntropyLoss()
trainer = Modified_Trainer(model, optimizer, loss_fn, "ViT_Experiment", device,config)
trainer.train(trainloader, testloader, epochs,save_model_every_n_epochs, output_attentions=False)

Files already downloaded and verified
Files already downloaded and verified




Epoch: 1, Train loss: 4.0657, Test loss: 3.7949, Accuracy: 0.1090




Epoch: 2, Train loss: 3.6932, Test loss: 3.5254, Accuracy: 0.1493




Epoch: 3, Train loss: 3.4912, Test loss: 3.4028, Accuracy: 0.1774




Epoch: 4, Train loss: 3.3702, Test loss: 3.2302, Accuracy: 0.2133




Epoch: 5, Train loss: 3.2719, Test loss: 3.1702, Accuracy: 0.2234
	Saving checkpoint at epoch 5




Epoch: 6, Train loss: 3.1941, Test loss: 3.1046, Accuracy: 0.2403




Epoch: 7, Train loss: 3.1395, Test loss: 3.0515, Accuracy: 0.2435




Epoch: 8, Train loss: 3.0692, Test loss: 2.9617, Accuracy: 0.2636




Epoch: 9, Train loss: 3.0084, Test loss: 2.9285, Accuracy: 0.2633




Epoch: 10, Train loss: 2.9524, Test loss: 2.8958, Accuracy: 0.2744
	Saving checkpoint at epoch 10




Epoch: 11, Train loss: 2.8883, Test loss: 2.8353, Accuracy: 0.2914




Epoch: 12, Train loss: 2.8339, Test loss: 2.7948, Accuracy: 0.2977




Epoch: 13, Train loss: 2.7762, Test loss: 2.7421, Accuracy: 0.3123




Epoch: 14, Train loss: 2.7272, Test loss: 2.7051, Accuracy: 0.3202




Epoch: 15, Train loss: 2.6788, Test loss: 2.6888, Accuracy: 0.3228
	Saving checkpoint at epoch 15




Epoch: 16, Train loss: 2.6370, Test loss: 2.6440, Accuracy: 0.3317




Epoch: 17, Train loss: 2.5998, Test loss: 2.6527, Accuracy: 0.3326




Epoch: 18, Train loss: 2.5645, Test loss: 2.6486, Accuracy: 0.3375




Epoch: 19, Train loss: 2.5291, Test loss: 2.6032, Accuracy: 0.3454




Epoch: 20, Train loss: 2.4911, Test loss: 2.5987, Accuracy: 0.3456
	Saving checkpoint at epoch 20




Epoch: 21, Train loss: 2.4545, Test loss: 2.5835, Accuracy: 0.3508




Epoch: 22, Train loss: 2.4312, Test loss: 2.5711, Accuracy: 0.3541




Epoch: 23, Train loss: 2.3980, Test loss: 2.5708, Accuracy: 0.3540




Epoch: 24, Train loss: 2.3664, Test loss: 2.5752, Accuracy: 0.3589




Epoch: 25, Train loss: 2.3432, Test loss: 2.5209, Accuracy: 0.3693
	Saving checkpoint at epoch 25




Epoch: 26, Train loss: 2.3121, Test loss: 2.5491, Accuracy: 0.3603




Epoch: 27, Train loss: 2.2710, Test loss: 2.5185, Accuracy: 0.3702




Epoch: 28, Train loss: 2.2593, Test loss: 2.5427, Accuracy: 0.3646




Epoch: 29, Train loss: 2.2176, Test loss: 2.5052, Accuracy: 0.3727




Epoch: 30, Train loss: 2.1919, Test loss: 2.5002, Accuracy: 0.3809
	Saving checkpoint at epoch 30




Epoch: 31, Train loss: 2.1679, Test loss: 2.5177, Accuracy: 0.3793




Epoch: 32, Train loss: 2.1415, Test loss: 2.5342, Accuracy: 0.3700




Epoch: 33, Train loss: 2.1130, Test loss: 2.4876, Accuracy: 0.3787




Epoch: 34, Train loss: 2.0939, Test loss: 2.4981, Accuracy: 0.3819




Epoch: 35, Train loss: 2.0687, Test loss: 2.4635, Accuracy: 0.3898
	Saving checkpoint at epoch 35




Epoch: 36, Train loss: 2.0528, Test loss: 2.4845, Accuracy: 0.3863




Epoch: 37, Train loss: 2.0239, Test loss: 2.5481, Accuracy: 0.3731




Epoch: 38, Train loss: 2.0170, Test loss: 2.4919, Accuracy: 0.3807




Epoch: 39, Train loss: 1.9872, Test loss: 2.5000, Accuracy: 0.3872




Epoch: 40, Train loss: 1.9716, Test loss: 2.4851, Accuracy: 0.3934
	Saving checkpoint at epoch 40




Epoch: 41, Train loss: 1.9566, Test loss: 2.5106, Accuracy: 0.3866




Epoch: 42, Train loss: 1.9368, Test loss: 2.5093, Accuracy: 0.3892




Epoch: 43, Train loss: 1.9118, Test loss: 2.5274, Accuracy: 0.3918




Epoch: 44, Train loss: 1.8972, Test loss: 2.5258, Accuracy: 0.3883




Epoch: 45, Train loss: 1.8858, Test loss: 2.5260, Accuracy: 0.3859
	Saving checkpoint at epoch 45




Epoch: 46, Train loss: 1.8647, Test loss: 2.5382, Accuracy: 0.3918




Epoch: 47, Train loss: 1.8503, Test loss: 2.5132, Accuracy: 0.3920




Epoch: 48, Train loss: 1.8304, Test loss: 2.5601, Accuracy: 0.3920




Epoch: 49, Train loss: 1.8198, Test loss: 2.5416, Accuracy: 0.3884




Epoch: 50, Train loss: 1.8094, Test loss: 2.5622, Accuracy: 0.3901
	Saving checkpoint at epoch 50




Epoch: 51, Train loss: 1.7797, Test loss: 2.5335, Accuracy: 0.3985




Epoch: 52, Train loss: 1.7764, Test loss: 2.5190, Accuracy: 0.3976




Epoch: 53, Train loss: 1.7486, Test loss: 2.5831, Accuracy: 0.3934




Epoch: 54, Train loss: 1.7427, Test loss: 2.5410, Accuracy: 0.3970




Epoch: 55, Train loss: 1.7314, Test loss: 2.5841, Accuracy: 0.3928
	Saving checkpoint at epoch 55




Epoch: 56, Train loss: 1.7205, Test loss: 2.5855, Accuracy: 0.3969




Epoch: 57, Train loss: 1.7091, Test loss: 2.5539, Accuracy: 0.4029




Epoch: 58, Train loss: 1.6939, Test loss: 2.5722, Accuracy: 0.3945




Epoch: 59, Train loss: 1.6827, Test loss: 2.5809, Accuracy: 0.3984




Epoch: 60, Train loss: 1.6544, Test loss: 2.6013, Accuracy: 0.3944
	Saving checkpoint at epoch 60




Epoch: 61, Train loss: 1.6618, Test loss: 2.6034, Accuracy: 0.3952




Epoch: 62, Train loss: 1.6456, Test loss: 2.5708, Accuracy: 0.4051




Epoch: 63, Train loss: 1.6307, Test loss: 2.5749, Accuracy: 0.4043




Epoch: 64, Train loss: 1.6256, Test loss: 2.5899, Accuracy: 0.4004




Epoch: 65, Train loss: 1.6092, Test loss: 2.5893, Accuracy: 0.3987
	Saving checkpoint at epoch 65




Epoch: 66, Train loss: 1.6021, Test loss: 2.5887, Accuracy: 0.4023




Epoch: 67, Train loss: 1.5887, Test loss: 2.6099, Accuracy: 0.3977




Epoch: 68, Train loss: 1.5873, Test loss: 2.6255, Accuracy: 0.4051




Epoch: 69, Train loss: 1.5671, Test loss: 2.6089, Accuracy: 0.4036




Epoch: 70, Train loss: 1.5620, Test loss: 2.5963, Accuracy: 0.4020
	Saving checkpoint at epoch 70




Epoch: 71, Train loss: 1.5457, Test loss: 2.5928, Accuracy: 0.4073




Epoch: 72, Train loss: 1.5354, Test loss: 2.6511, Accuracy: 0.4003




Epoch: 73, Train loss: 1.5343, Test loss: 2.6431, Accuracy: 0.4058




Epoch: 74, Train loss: 1.5195, Test loss: 2.6405, Accuracy: 0.3982




Epoch: 75, Train loss: 1.5024, Test loss: 2.7033, Accuracy: 0.3946
	Saving checkpoint at epoch 75




Epoch: 76, Train loss: 1.4946, Test loss: 2.6722, Accuracy: 0.4028




Epoch: 77, Train loss: 1.4824, Test loss: 2.7256, Accuracy: 0.3992




Epoch: 78, Train loss: 1.4793, Test loss: 2.6452, Accuracy: 0.4097




Epoch: 79, Train loss: 1.4701, Test loss: 2.6705, Accuracy: 0.4035




Epoch: 80, Train loss: 1.4648, Test loss: 2.7031, Accuracy: 0.4007
	Saving checkpoint at epoch 80




Epoch: 81, Train loss: 1.4530, Test loss: 2.6903, Accuracy: 0.4098




Epoch: 82, Train loss: 1.4577, Test loss: 2.6565, Accuracy: 0.4062




Epoch: 83, Train loss: 1.4301, Test loss: 2.6875, Accuracy: 0.4063




Epoch: 84, Train loss: 1.4321, Test loss: 2.7060, Accuracy: 0.4049




Epoch: 85, Train loss: 1.4173, Test loss: 2.6832, Accuracy: 0.4062
	Saving checkpoint at epoch 85




Epoch: 86, Train loss: 1.4176, Test loss: 2.6894, Accuracy: 0.4090




Epoch: 87, Train loss: 1.4025, Test loss: 2.6847, Accuracy: 0.4075




Epoch: 88, Train loss: 1.3977, Test loss: 2.7357, Accuracy: 0.4085




Epoch: 89, Train loss: 1.3855, Test loss: 2.7273, Accuracy: 0.4040




Epoch: 90, Train loss: 1.3824, Test loss: 2.7366, Accuracy: 0.4044
	Saving checkpoint at epoch 90




Epoch: 91, Train loss: 1.3713, Test loss: 2.6946, Accuracy: 0.4047




Epoch: 92, Train loss: 1.3700, Test loss: 2.7587, Accuracy: 0.4067




Epoch: 93, Train loss: 1.3670, Test loss: 2.7551, Accuracy: 0.4110




Epoch: 94, Train loss: 1.3487, Test loss: 2.7384, Accuracy: 0.4087




Epoch: 95, Train loss: 1.3518, Test loss: 2.7741, Accuracy: 0.4063
	Saving checkpoint at epoch 95




Epoch: 96, Train loss: 1.3392, Test loss: 2.7724, Accuracy: 0.4056




Epoch: 97, Train loss: 1.3379, Test loss: 2.7147, Accuracy: 0.4081




Epoch: 98, Train loss: 1.3303, Test loss: 2.7509, Accuracy: 0.4074




Epoch: 99, Train loss: 1.3194, Test loss: 2.8220, Accuracy: 0.3973




Epoch: 100, Train loss: 1.3212, Test loss: 2.7363, Accuracy: 0.4102
	Saving checkpoint at epoch 100
Final model and experiment details saved under ViT_Experiment


In [43]:
config = {
    "patch_size": 8,  # Input image size: 32x32 -> 8x8 patches
    "embed_dim": 64,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "hidden_dim": 64 * 16, ## 32 * embed_dim -> i think this helped increase accuracy on CIFAR100
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 100, # num_classes of CIFAR100
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}

# These are not hard constraints, but are used to prevent misconfigurations
assert config["embed_dim"] % config["num_attention_heads"] == 0
#assert config['hidden_dim'] == 4 * config['embed_dim']
assert config['image_size'] % config['patch_size'] == 0

model = ViTForClassification(config)

# Configuration parameters
exp_name = "ViT_CIFAR100"
batch_size = 64
epochs = 100
lr = 0.001
device = "cuda" if torch.cuda.is_available() else "cpu"
save_model_every_n_epochs = 5

trainloader, testloader, _ = modified_prepare_data(batch_size)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.001)
loss_fn = nn.CrossEntropyLoss()
trainer = Modified_Trainer(model, optimizer, loss_fn, "ViT_Experiment", device,config)
trainer.train(trainloader, testloader, epochs,save_model_every_n_epochs, output_attentions=False)

Files already downloaded and verified
Files already downloaded and verified




Epoch: 1, Train loss: 4.0672, Test loss: 3.8169, Accuracy: 0.1013




Epoch: 2, Train loss: 3.6923, Test loss: 3.5790, Accuracy: 0.1506




Epoch: 3, Train loss: 3.4947, Test loss: 3.3695, Accuracy: 0.1877




Epoch: 4, Train loss: 3.3640, Test loss: 3.2689, Accuracy: 0.2011




Epoch: 5, Train loss: 3.2746, Test loss: 3.1715, Accuracy: 0.2235
	Saving checkpoint at epoch 5




Epoch: 6, Train loss: 3.1988, Test loss: 3.0970, Accuracy: 0.2353




Epoch: 7, Train loss: 3.1333, Test loss: 3.0636, Accuracy: 0.2455




Epoch: 8, Train loss: 3.0653, Test loss: 2.9902, Accuracy: 0.2616




Epoch: 9, Train loss: 3.0281, Test loss: 2.9288, Accuracy: 0.2731




Epoch: 10, Train loss: 2.9811, Test loss: 2.9019, Accuracy: 0.2752
	Saving checkpoint at epoch 10




Epoch: 11, Train loss: 2.9374, Test loss: 2.8626, Accuracy: 0.2860




Epoch: 12, Train loss: 2.8944, Test loss: 2.8000, Accuracy: 0.2960




Epoch: 13, Train loss: 2.8542, Test loss: 2.7907, Accuracy: 0.3037




Epoch: 14, Train loss: 2.8130, Test loss: 2.7567, Accuracy: 0.3058




Epoch: 15, Train loss: 2.7736, Test loss: 2.7094, Accuracy: 0.3165
	Saving checkpoint at epoch 15




Epoch: 16, Train loss: 2.7469, Test loss: 2.7071, Accuracy: 0.3198




Epoch: 17, Train loss: 2.7113, Test loss: 2.7079, Accuracy: 0.3212




Epoch: 18, Train loss: 2.6834, Test loss: 2.6582, Accuracy: 0.3280




Epoch: 19, Train loss: 2.6539, Test loss: 2.6009, Accuracy: 0.3397




Epoch: 20, Train loss: 2.6358, Test loss: 2.6253, Accuracy: 0.3418
	Saving checkpoint at epoch 20




Epoch: 21, Train loss: 2.6128, Test loss: 2.5854, Accuracy: 0.3486




Epoch: 22, Train loss: 2.5852, Test loss: 2.5804, Accuracy: 0.3496




Epoch: 23, Train loss: 2.5695, Test loss: 2.5408, Accuracy: 0.3542




Epoch: 24, Train loss: 2.5431, Test loss: 2.5452, Accuracy: 0.3559




Epoch: 25, Train loss: 2.5195, Test loss: 2.5225, Accuracy: 0.3612
	Saving checkpoint at epoch 25




Epoch: 26, Train loss: 2.5057, Test loss: 2.5168, Accuracy: 0.3626




Epoch: 27, Train loss: 2.4785, Test loss: 2.5227, Accuracy: 0.3574




Epoch: 28, Train loss: 2.4588, Test loss: 2.4828, Accuracy: 0.3682




Epoch: 29, Train loss: 2.4388, Test loss: 2.4786, Accuracy: 0.3676




Epoch: 30, Train loss: 2.4174, Test loss: 2.4505, Accuracy: 0.3704
	Saving checkpoint at epoch 30




Epoch: 31, Train loss: 2.4001, Test loss: 2.4277, Accuracy: 0.3835




Epoch: 32, Train loss: 2.3776, Test loss: 2.4708, Accuracy: 0.3739




Epoch: 33, Train loss: 2.3635, Test loss: 2.3952, Accuracy: 0.3883




Epoch: 34, Train loss: 2.3463, Test loss: 2.4215, Accuracy: 0.3812




Epoch: 35, Train loss: 2.3275, Test loss: 2.4038, Accuracy: 0.3872
	Saving checkpoint at epoch 35




Epoch: 36, Train loss: 2.3121, Test loss: 2.4230, Accuracy: 0.3841




Epoch: 37, Train loss: 2.2967, Test loss: 2.3781, Accuracy: 0.3914




Epoch: 38, Train loss: 2.2875, Test loss: 2.3638, Accuracy: 0.3936




Epoch: 39, Train loss: 2.2600, Test loss: 2.3543, Accuracy: 0.3994




Epoch: 40, Train loss: 2.2514, Test loss: 2.3583, Accuracy: 0.3952
	Saving checkpoint at epoch 40




Epoch: 41, Train loss: 2.2384, Test loss: 2.3355, Accuracy: 0.3993




Epoch: 42, Train loss: 2.2242, Test loss: 2.3389, Accuracy: 0.4054




Epoch: 43, Train loss: 2.2063, Test loss: 2.3202, Accuracy: 0.4050




Epoch: 44, Train loss: 2.1915, Test loss: 2.3100, Accuracy: 0.4067




Epoch: 45, Train loss: 2.1830, Test loss: 2.3029, Accuracy: 0.4104
	Saving checkpoint at epoch 45




Epoch: 46, Train loss: 2.1644, Test loss: 2.2986, Accuracy: 0.4126




Epoch: 47, Train loss: 2.1491, Test loss: 2.3208, Accuracy: 0.4097




Epoch: 48, Train loss: 2.1376, Test loss: 2.2960, Accuracy: 0.4098




Epoch: 49, Train loss: 2.1244, Test loss: 2.2889, Accuracy: 0.4152




Epoch: 50, Train loss: 2.1142, Test loss: 2.2996, Accuracy: 0.4132
	Saving checkpoint at epoch 50




Epoch: 51, Train loss: 2.1073, Test loss: 2.2735, Accuracy: 0.4189




Epoch: 52, Train loss: 2.0854, Test loss: 2.2829, Accuracy: 0.4198




Epoch: 53, Train loss: 2.0800, Test loss: 2.2971, Accuracy: 0.4146




Epoch: 54, Train loss: 2.0767, Test loss: 2.2722, Accuracy: 0.4216




Epoch: 55, Train loss: 2.0552, Test loss: 2.2715, Accuracy: 0.4185
	Saving checkpoint at epoch 55




Epoch: 56, Train loss: 2.0485, Test loss: 2.2825, Accuracy: 0.4190




Epoch: 57, Train loss: 2.0405, Test loss: 2.2770, Accuracy: 0.4140




Epoch: 58, Train loss: 2.0216, Test loss: 2.2647, Accuracy: 0.4251




Epoch: 59, Train loss: 2.0244, Test loss: 2.2559, Accuracy: 0.4249




Epoch: 60, Train loss: 2.0133, Test loss: 2.2576, Accuracy: 0.4237
	Saving checkpoint at epoch 60




Epoch: 61, Train loss: 1.9969, Test loss: 2.2370, Accuracy: 0.4258




Epoch: 62, Train loss: 1.9876, Test loss: 2.2453, Accuracy: 0.4294




Epoch: 63, Train loss: 1.9752, Test loss: 2.2310, Accuracy: 0.4296




Epoch: 64, Train loss: 1.9607, Test loss: 2.2376, Accuracy: 0.4301




Epoch: 65, Train loss: 1.9542, Test loss: 2.2553, Accuracy: 0.4315
	Saving checkpoint at epoch 65




Epoch: 66, Train loss: 1.9535, Test loss: 2.2303, Accuracy: 0.4297




Epoch: 67, Train loss: 1.9370, Test loss: 2.2001, Accuracy: 0.4402




Epoch: 68, Train loss: 1.9256, Test loss: 2.2240, Accuracy: 0.4348




Epoch: 69, Train loss: 1.9199, Test loss: 2.1822, Accuracy: 0.4396




Epoch: 70, Train loss: 1.9174, Test loss: 2.2093, Accuracy: 0.4380
	Saving checkpoint at epoch 70




Epoch: 71, Train loss: 1.8968, Test loss: 2.2135, Accuracy: 0.4317




Epoch: 72, Train loss: 1.8911, Test loss: 2.2029, Accuracy: 0.4357




Epoch: 73, Train loss: 1.8846, Test loss: 2.1786, Accuracy: 0.4412




Epoch: 74, Train loss: 1.8749, Test loss: 2.2064, Accuracy: 0.4343




Epoch: 75, Train loss: 1.8643, Test loss: 2.2083, Accuracy: 0.4409
	Saving checkpoint at epoch 75




Epoch: 76, Train loss: 1.8575, Test loss: 2.1903, Accuracy: 0.4419




Epoch: 77, Train loss: 1.8532, Test loss: 2.2363, Accuracy: 0.4412




Epoch: 78, Train loss: 1.8396, Test loss: 2.2063, Accuracy: 0.4409




Epoch: 79, Train loss: 1.8337, Test loss: 2.2096, Accuracy: 0.4424




Epoch: 80, Train loss: 1.8197, Test loss: 2.2410, Accuracy: 0.4363
	Saving checkpoint at epoch 80




Epoch: 81, Train loss: 1.8187, Test loss: 2.1809, Accuracy: 0.4467




Epoch: 82, Train loss: 1.8086, Test loss: 2.2112, Accuracy: 0.4390




Epoch: 83, Train loss: 1.8027, Test loss: 2.1912, Accuracy: 0.4504




Epoch: 84, Train loss: 1.7968, Test loss: 2.1920, Accuracy: 0.4480




Epoch: 85, Train loss: 1.7854, Test loss: 2.1731, Accuracy: 0.4466
	Saving checkpoint at epoch 85




Epoch: 86, Train loss: 1.7784, Test loss: 2.2089, Accuracy: 0.4423




Epoch: 87, Train loss: 1.7700, Test loss: 2.2005, Accuracy: 0.4546




Epoch: 88, Train loss: 1.7580, Test loss: 2.1823, Accuracy: 0.4466




Epoch: 89, Train loss: 1.7483, Test loss: 2.1929, Accuracy: 0.4453




Epoch: 90, Train loss: 1.7444, Test loss: 2.2175, Accuracy: 0.4500
	Saving checkpoint at epoch 90




Epoch: 91, Train loss: 1.7385, Test loss: 2.1881, Accuracy: 0.4484




Epoch: 92, Train loss: 1.7272, Test loss: 2.2280, Accuracy: 0.4431




Epoch: 93, Train loss: 1.7282, Test loss: 2.1906, Accuracy: 0.4502




Epoch: 94, Train loss: 1.7259, Test loss: 2.1790, Accuracy: 0.4553




Epoch: 95, Train loss: 1.7144, Test loss: 2.1948, Accuracy: 0.4512
	Saving checkpoint at epoch 95




Epoch: 96, Train loss: 1.7128, Test loss: 2.1939, Accuracy: 0.4523




Epoch: 97, Train loss: 1.7093, Test loss: 2.1784, Accuracy: 0.4470




Epoch: 98, Train loss: 1.6959, Test loss: 2.1991, Accuracy: 0.4538




Epoch: 99, Train loss: 1.6915, Test loss: 2.2086, Accuracy: 0.4525




Epoch: 100, Train loss: 1.6776, Test loss: 2.2055, Accuracy: 0.4478
	Saving checkpoint at epoch 100
Final model and experiment details saved under ViT_Experiment


In [41]:
config_teacher = {
    "patch_size": 8,  # Input image size: 32x32 -> 8x8 patches
    "embed_dim": 64,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "hidden_dim": 64 * 16, ## 32 * embed_dim -> i think this helped increase accuracy on CIFAR100
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 100, # num_classes of CIFAR100
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}


config_student = {
    "patch_size": 8,  # Input image size: 32x32 -> 8x8 patches
    "embed_dim": 64,
    "num_hidden_layers": 3,
    "num_attention_heads": 4,
    "hidden_dim": 64 * 8, ## 32 * embed_dim -> i think this helped increase accuracy on CIFAR100
    "dropout_val": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 100, # num_classes of CIFAR100
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,
    "attention_block_index": 1  # Specify the block index you want to observe
}




exp_name = "ViT_Student_teacher"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


teacher_model = ViTForClassification(config_teacher)
student_model = ViTForClassification(config_student)

optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=0.001)
optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

save_model_every_n_epochs = 5
epoch_teacher = 100  # Adjust the number of epochs as needed
epoch_student = 100  # Adjust the number of epochs as needed
batch_size = 64  # Adjust the batch size as needed
trainloader, testloader, _ = modified_prepare_data(batch_size)


alpha = 0.5 # THIS DEFINES THE WEIGHTED LOSS. IF ALPHA IS LESS THAN 0.5, THEN MORE EMPHASIS ON ATTENTION LOSS. 




trainer = TS_Trainer(teacher_model, student_model, optimizer_teacher, optimizer_student, loss_fn,exp_name, device,alpha, base_dir="experiments")
trainer.train_teacher(config_teacher, trainloader,testloader, epoch_teacher,save_model_every_n_epochs)  # Train teacher for 100 epochs
trainer.train_student(config_student,trainloader,testloader, epoch_student,save_model_every_n_epochs)  # Train student for 100 epochs

Files already downloaded and verified
Files already downloaded and verified




Epoch: 1, Teacher Train loss: 4.0779, Test loss: 3.8299, Accuracy: 0.1022




Epoch: 2, Teacher Train loss: 3.7134, Test loss: 3.5582, Accuracy: 0.1488




Epoch: 3, Teacher Train loss: 3.5013, Test loss: 3.3577, Accuracy: 0.1895




Epoch: 4, Teacher Train loss: 3.3784, Test loss: 3.2533, Accuracy: 0.2022




Epoch: 5, Teacher Train loss: 3.2749, Test loss: 3.1874, Accuracy: 0.2224
	Saving checkpoint at epoch 5




Epoch: 6, Teacher Train loss: 3.1999, Test loss: 3.0991, Accuracy: 0.2392




Epoch: 7, Teacher Train loss: 3.1368, Test loss: 3.0509, Accuracy: 0.2441




Epoch: 8, Teacher Train loss: 3.0734, Test loss: 2.9985, Accuracy: 0.2603




Epoch: 9, Teacher Train loss: 3.0229, Test loss: 2.9280, Accuracy: 0.2779




Epoch: 10, Teacher Train loss: 2.9735, Test loss: 2.8811, Accuracy: 0.2839
	Saving checkpoint at epoch 10




Epoch: 11, Teacher Train loss: 2.9282, Test loss: 2.8705, Accuracy: 0.2821




Epoch: 12, Teacher Train loss: 2.8861, Test loss: 2.8065, Accuracy: 0.2936




Epoch: 13, Teacher Train loss: 2.8425, Test loss: 2.7553, Accuracy: 0.3131




Epoch: 14, Teacher Train loss: 2.8092, Test loss: 2.7604, Accuracy: 0.3055




Epoch: 15, Teacher Train loss: 2.7764, Test loss: 2.7440, Accuracy: 0.3123
	Saving checkpoint at epoch 15




Epoch: 16, Teacher Train loss: 2.7439, Test loss: 2.7001, Accuracy: 0.3231




Epoch: 17, Teacher Train loss: 2.7156, Test loss: 2.6553, Accuracy: 0.3325




Epoch: 18, Teacher Train loss: 2.6947, Test loss: 2.6346, Accuracy: 0.3380




Epoch: 19, Teacher Train loss: 2.6547, Test loss: 2.6616, Accuracy: 0.3269




Epoch: 20, Teacher Train loss: 2.6388, Test loss: 2.5750, Accuracy: 0.3464
	Saving checkpoint at epoch 20




Epoch: 21, Teacher Train loss: 2.6030, Test loss: 2.5804, Accuracy: 0.3473




Epoch: 22, Teacher Train loss: 2.5831, Test loss: 2.5686, Accuracy: 0.3538




Epoch: 23, Teacher Train loss: 2.5572, Test loss: 2.5501, Accuracy: 0.3498




Epoch: 24, Teacher Train loss: 2.5420, Test loss: 2.5324, Accuracy: 0.3586




Epoch: 25, Teacher Train loss: 2.5083, Test loss: 2.5223, Accuracy: 0.3631
	Saving checkpoint at epoch 25




Epoch: 26, Teacher Train loss: 2.4960, Test loss: 2.4954, Accuracy: 0.3655




Epoch: 27, Teacher Train loss: 2.4731, Test loss: 2.4885, Accuracy: 0.3664




Epoch: 28, Teacher Train loss: 2.4498, Test loss: 2.4627, Accuracy: 0.3759




Epoch: 29, Teacher Train loss: 2.4346, Test loss: 2.4515, Accuracy: 0.3707




Epoch: 30, Teacher Train loss: 2.4144, Test loss: 2.4755, Accuracy: 0.3746
	Saving checkpoint at epoch 30




Epoch: 31, Teacher Train loss: 2.3891, Test loss: 2.3908, Accuracy: 0.3851




Epoch: 32, Teacher Train loss: 2.3678, Test loss: 2.4046, Accuracy: 0.3826




Epoch: 33, Teacher Train loss: 2.3510, Test loss: 2.3965, Accuracy: 0.3852




Epoch: 34, Teacher Train loss: 2.3338, Test loss: 2.3528, Accuracy: 0.3994




Epoch: 35, Teacher Train loss: 2.3184, Test loss: 2.3797, Accuracy: 0.3913
	Saving checkpoint at epoch 35




Epoch: 36, Teacher Train loss: 2.3010, Test loss: 2.3247, Accuracy: 0.4045




Epoch: 37, Teacher Train loss: 2.2885, Test loss: 2.3568, Accuracy: 0.3985




Epoch: 38, Teacher Train loss: 2.2672, Test loss: 2.3247, Accuracy: 0.4031




Epoch: 39, Teacher Train loss: 2.2524, Test loss: 2.3050, Accuracy: 0.4122




Epoch: 40, Teacher Train loss: 2.2388, Test loss: 2.3047, Accuracy: 0.4125
	Saving checkpoint at epoch 40




Epoch: 41, Teacher Train loss: 2.2128, Test loss: 2.3203, Accuracy: 0.4113




Epoch: 42, Teacher Train loss: 2.2044, Test loss: 2.2955, Accuracy: 0.4162




Epoch: 43, Teacher Train loss: 2.1979, Test loss: 2.2860, Accuracy: 0.4188




Epoch: 44, Teacher Train loss: 2.1760, Test loss: 2.2954, Accuracy: 0.4153




Epoch: 45, Teacher Train loss: 2.1633, Test loss: 2.2896, Accuracy: 0.4145
	Saving checkpoint at epoch 45




Epoch: 46, Teacher Train loss: 2.1460, Test loss: 2.2862, Accuracy: 0.4151




Epoch: 47, Teacher Train loss: 2.1439, Test loss: 2.2514, Accuracy: 0.4233




Epoch: 48, Teacher Train loss: 2.1173, Test loss: 2.2557, Accuracy: 0.4239




Epoch: 49, Teacher Train loss: 2.1112, Test loss: 2.2595, Accuracy: 0.4172




Epoch: 50, Teacher Train loss: 2.0947, Test loss: 2.2688, Accuracy: 0.4203
	Saving checkpoint at epoch 50




Epoch: 51, Teacher Train loss: 2.0862, Test loss: 2.2348, Accuracy: 0.4229




Epoch: 52, Teacher Train loss: 2.0666, Test loss: 2.2589, Accuracy: 0.4229




Epoch: 53, Teacher Train loss: 2.0544, Test loss: 2.2368, Accuracy: 0.4275




Epoch: 54, Teacher Train loss: 2.0428, Test loss: 2.2478, Accuracy: 0.4298




Epoch: 55, Teacher Train loss: 2.0371, Test loss: 2.2231, Accuracy: 0.4284
	Saving checkpoint at epoch 55




Epoch: 56, Teacher Train loss: 2.0209, Test loss: 2.2423, Accuracy: 0.4253




Epoch: 57, Teacher Train loss: 2.0163, Test loss: 2.2244, Accuracy: 0.4343




Epoch: 58, Teacher Train loss: 2.0025, Test loss: 2.2028, Accuracy: 0.4371




Epoch: 59, Teacher Train loss: 1.9913, Test loss: 2.2332, Accuracy: 0.4324




Epoch: 60, Teacher Train loss: 1.9799, Test loss: 2.2539, Accuracy: 0.4242
	Saving checkpoint at epoch 60




Epoch: 61, Teacher Train loss: 1.9702, Test loss: 2.2170, Accuracy: 0.4339




Epoch: 62, Teacher Train loss: 1.9516, Test loss: 2.1925, Accuracy: 0.4345




Epoch: 63, Teacher Train loss: 1.9462, Test loss: 2.2223, Accuracy: 0.4289




Epoch: 64, Teacher Train loss: 1.9389, Test loss: 2.1955, Accuracy: 0.4334




Epoch: 65, Teacher Train loss: 1.9305, Test loss: 2.2046, Accuracy: 0.4358
	Saving checkpoint at epoch 65




Epoch: 66, Teacher Train loss: 1.9179, Test loss: 2.2045, Accuracy: 0.4370




Epoch: 67, Teacher Train loss: 1.9034, Test loss: 2.2096, Accuracy: 0.4384




Epoch: 68, Teacher Train loss: 1.8974, Test loss: 2.1880, Accuracy: 0.4431




Epoch: 69, Teacher Train loss: 1.8907, Test loss: 2.1879, Accuracy: 0.4432




Epoch: 70, Teacher Train loss: 1.8770, Test loss: 2.1887, Accuracy: 0.4454
	Saving checkpoint at epoch 70




Epoch: 71, Teacher Train loss: 1.8702, Test loss: 2.1887, Accuracy: 0.4437




Epoch: 72, Teacher Train loss: 1.8618, Test loss: 2.2027, Accuracy: 0.4412




Epoch: 73, Teacher Train loss: 1.8539, Test loss: 2.1773, Accuracy: 0.4499




Epoch: 74, Teacher Train loss: 1.8458, Test loss: 2.2049, Accuracy: 0.4385




Epoch: 75, Teacher Train loss: 1.8375, Test loss: 2.1797, Accuracy: 0.4463
	Saving checkpoint at epoch 75




Epoch: 76, Teacher Train loss: 1.8251, Test loss: 2.1717, Accuracy: 0.4488




Epoch: 77, Teacher Train loss: 1.8179, Test loss: 2.1697, Accuracy: 0.4522




Epoch: 78, Teacher Train loss: 1.8132, Test loss: 2.1750, Accuracy: 0.4462




Epoch: 79, Teacher Train loss: 1.8044, Test loss: 2.1711, Accuracy: 0.4485




Epoch: 80, Teacher Train loss: 1.7981, Test loss: 2.1622, Accuracy: 0.4508
	Saving checkpoint at epoch 80




Epoch: 81, Teacher Train loss: 1.7924, Test loss: 2.2067, Accuracy: 0.4460




Epoch: 82, Teacher Train loss: 1.7780, Test loss: 2.1500, Accuracy: 0.4553




Epoch: 83, Teacher Train loss: 1.7751, Test loss: 2.1616, Accuracy: 0.4558




Epoch: 84, Teacher Train loss: 1.7653, Test loss: 2.1681, Accuracy: 0.4518




Epoch: 85, Teacher Train loss: 1.7573, Test loss: 2.1450, Accuracy: 0.4524
	Saving checkpoint at epoch 85




Epoch: 86, Teacher Train loss: 1.7496, Test loss: 2.1812, Accuracy: 0.4528




Epoch: 87, Teacher Train loss: 1.7479, Test loss: 2.1568, Accuracy: 0.4584




Epoch: 88, Teacher Train loss: 1.7340, Test loss: 2.1688, Accuracy: 0.4572




Epoch: 89, Teacher Train loss: 1.7187, Test loss: 2.1671, Accuracy: 0.4603




Epoch: 90, Teacher Train loss: 1.7181, Test loss: 2.1776, Accuracy: 0.4552
	Saving checkpoint at epoch 90




Epoch: 91, Teacher Train loss: 1.7147, Test loss: 2.1690, Accuracy: 0.4590




Epoch: 92, Teacher Train loss: 1.7058, Test loss: 2.1769, Accuracy: 0.4553




Epoch: 93, Teacher Train loss: 1.7026, Test loss: 2.1650, Accuracy: 0.4537




Epoch: 94, Teacher Train loss: 1.6945, Test loss: 2.1693, Accuracy: 0.4579




Epoch: 95, Teacher Train loss: 1.6838, Test loss: 2.1817, Accuracy: 0.4608
	Saving checkpoint at epoch 95




Epoch: 96, Teacher Train loss: 1.6754, Test loss: 2.1333, Accuracy: 0.4658




Epoch: 97, Teacher Train loss: 1.6709, Test loss: 2.1841, Accuracy: 0.4564




Epoch: 98, Teacher Train loss: 1.6736, Test loss: 2.1537, Accuracy: 0.4553




Epoch: 99, Teacher Train loss: 1.6647, Test loss: 2.1588, Accuracy: 0.4600




Epoch: 100, Teacher Train loss: 1.6579, Test loss: 2.1690, Accuracy: 0.4618
	Saving checkpoint at epoch 100
teacher_final and experiment details saved under ViT_Student_teacher




attention loss for one batch  tensor(131.7242, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(6.9839e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 1, Student Train loss: 69.4073, Test loss: 3.7283,Attention Loss: 131.7242, Accuracy: 0.1118




attention loss for one batch  tensor(130.5045, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(7.6135e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 2, Student Train loss: 68.6522, Test loss: 3.4756,Attention Loss: 130.5045, Accuracy: 0.1607




attention loss for one batch  tensor(135.3957, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(7.3947e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 3, Student Train loss: 68.5715, Test loss: 3.2780,Attention Loss: 135.3957, Accuracy: 0.1977




attention loss for one batch  tensor(134.6920, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.6499e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 4, Student Train loss: 68.4803, Test loss: 3.1232,Attention Loss: 134.6920, Accuracy: 0.2302




attention loss for one batch  tensor(137.6198, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(6.5532e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 5, Student Train loss: 68.3727, Test loss: 2.9958,Attention Loss: 137.6198, Accuracy: 0.2487
	Saving checkpoint at epoch 5




attention loss for one batch  tensor(132.9446, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(6.3427e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 6, Student Train loss: 68.3270, Test loss: 2.9072,Attention Loss: 132.9446, Accuracy: 0.2704




attention loss for one batch  tensor(134.1730, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(6.6757e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 7, Student Train loss: 68.2760, Test loss: 2.8210,Attention Loss: 134.1730, Accuracy: 0.2886




attention loss for one batch  tensor(131.0259, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.3524e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 8, Student Train loss: 68.2565, Test loss: 2.7427,Attention Loss: 131.0259, Accuracy: 0.3046




attention loss for one batch  tensor(134.3921, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.1320e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 9, Student Train loss: 68.2196, Test loss: 2.6441,Attention Loss: 134.3921, Accuracy: 0.3287




attention loss for one batch  tensor(132.4533, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.1413e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 10, Student Train loss: 68.1639, Test loss: 2.6090,Attention Loss: 132.4533, Accuracy: 0.3300
	Saving checkpoint at epoch 10




attention loss for one batch  tensor(131.9235, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.4575e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 11, Student Train loss: 68.0902, Test loss: 2.5638,Attention Loss: 131.9235, Accuracy: 0.3425




attention loss for one batch  tensor(137.3901, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.0375e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 12, Student Train loss: 68.1310, Test loss: 2.5273,Attention Loss: 137.3901, Accuracy: 0.3482




attention loss for one batch  tensor(131.6419, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.4908e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 13, Student Train loss: 68.0774, Test loss: 2.4640,Attention Loss: 131.6419, Accuracy: 0.3647




attention loss for one batch  tensor(141.4880, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.0782e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 14, Student Train loss: 68.0276, Test loss: 2.4401,Attention Loss: 141.4880, Accuracy: 0.3662




attention loss for one batch  tensor(133.5088, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.9928e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 15, Student Train loss: 68.0131, Test loss: 2.4089,Attention Loss: 133.5088, Accuracy: 0.3798
	Saving checkpoint at epoch 15




attention loss for one batch  tensor(133.6148, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.1500e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 16, Student Train loss: 68.0449, Test loss: 2.4166,Attention Loss: 133.6148, Accuracy: 0.3667




attention loss for one batch  tensor(132.8211, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.5514e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 17, Student Train loss: 68.0045, Test loss: 2.3854,Attention Loss: 132.8211, Accuracy: 0.3837




attention loss for one batch  tensor(135.5042, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.1479e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 18, Student Train loss: 67.9978, Test loss: 2.3608,Attention Loss: 135.5042, Accuracy: 0.3901




attention loss for one batch  tensor(135.3502, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.8790e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 19, Student Train loss: 67.9913, Test loss: 2.3575,Attention Loss: 135.3502, Accuracy: 0.3922




attention loss for one batch  tensor(133.5346, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.3167e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 20, Student Train loss: 67.8933, Test loss: 2.3209,Attention Loss: 133.5346, Accuracy: 0.3989
	Saving checkpoint at epoch 20




attention loss for one batch  tensor(126.5842, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.7390e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 21, Student Train loss: 67.9377, Test loss: 2.3286,Attention Loss: 126.5842, Accuracy: 0.3922




attention loss for one batch  tensor(139.6100, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(5.4225e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 22, Student Train loss: 67.9435, Test loss: 2.3015,Attention Loss: 139.6100, Accuracy: 0.4054




attention loss for one batch  tensor(128.9316, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.8913e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 23, Student Train loss: 67.9171, Test loss: 2.3242,Attention Loss: 128.9316, Accuracy: 0.3995




attention loss for one batch  tensor(134.4232, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5195e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 24, Student Train loss: 67.8171, Test loss: 2.3095,Attention Loss: 134.4232, Accuracy: 0.3989




attention loss for one batch  tensor(139.1593, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.9619e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 25, Student Train loss: 67.9010, Test loss: 2.3036,Attention Loss: 139.1593, Accuracy: 0.4062
	Saving checkpoint at epoch 25




attention loss for one batch  tensor(132.4450, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5857e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 26, Student Train loss: 67.8675, Test loss: 2.2819,Attention Loss: 132.4450, Accuracy: 0.4093




attention loss for one batch  tensor(129.7392, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.6998e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 27, Student Train loss: 67.8662, Test loss: 2.2951,Attention Loss: 129.7392, Accuracy: 0.4057




attention loss for one batch  tensor(138.6998, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.6631e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 28, Student Train loss: 67.8419, Test loss: 2.2801,Attention Loss: 138.6998, Accuracy: 0.4104




attention loss for one batch  tensor(130.9919, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.1398e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 29, Student Train loss: 67.8650, Test loss: 2.3051,Attention Loss: 130.9919, Accuracy: 0.4137




attention loss for one batch  tensor(135.4554, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.3902e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 30, Student Train loss: 67.8298, Test loss: 2.2861,Attention Loss: 135.4554, Accuracy: 0.4149
	Saving checkpoint at epoch 30




attention loss for one batch  tensor(132.3778, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.4913e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 31, Student Train loss: 67.8285, Test loss: 2.3108,Attention Loss: 132.3778, Accuracy: 0.4137




attention loss for one batch  tensor(129.4956, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.1088e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 32, Student Train loss: 67.8111, Test loss: 2.3274,Attention Loss: 129.4956, Accuracy: 0.4088




attention loss for one batch  tensor(127.3759, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.8856e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 33, Student Train loss: 67.8061, Test loss: 2.2688,Attention Loss: 127.3759, Accuracy: 0.4238




attention loss for one batch  tensor(132.2744, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.6002e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 34, Student Train loss: 67.7554, Test loss: 2.2590,Attention Loss: 132.2744, Accuracy: 0.4230




attention loss for one batch  tensor(133.9398, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.3105e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 35, Student Train loss: 67.8387, Test loss: 2.2782,Attention Loss: 133.9398, Accuracy: 0.4211
	Saving checkpoint at epoch 35




attention loss for one batch  tensor(139.1645, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.3413e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 36, Student Train loss: 67.7190, Test loss: 2.2722,Attention Loss: 139.1645, Accuracy: 0.4237




attention loss for one batch  tensor(133.3110, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.2672e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 37, Student Train loss: 67.8396, Test loss: 2.2690,Attention Loss: 133.3110, Accuracy: 0.4216




attention loss for one batch  tensor(136.0365, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5625e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 38, Student Train loss: 67.7174, Test loss: 2.2782,Attention Loss: 136.0365, Accuracy: 0.4245




attention loss for one batch  tensor(138.6964, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.8783e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 39, Student Train loss: 67.7840, Test loss: 2.2767,Attention Loss: 138.6964, Accuracy: 0.4297




attention loss for one batch  tensor(138.8379, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.0212e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 40, Student Train loss: 67.7475, Test loss: 2.2261,Attention Loss: 138.8379, Accuracy: 0.4383
	Saving checkpoint at epoch 40




attention loss for one batch  tensor(131.7893, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.1827e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 41, Student Train loss: 67.7833, Test loss: 2.2831,Attention Loss: 131.7893, Accuracy: 0.4271




attention loss for one batch  tensor(127.2844, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.5062e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 42, Student Train loss: 67.7026, Test loss: 2.2810,Attention Loss: 127.2844, Accuracy: 0.4242




attention loss for one batch  tensor(133.1751, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.1401e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 43, Student Train loss: 67.7435, Test loss: 2.3106,Attention Loss: 133.1751, Accuracy: 0.4231




attention loss for one batch  tensor(132.3475, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.4023e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 44, Student Train loss: 67.6820, Test loss: 2.2870,Attention Loss: 132.3475, Accuracy: 0.4238




attention loss for one batch  tensor(136.5151, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.2367e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 45, Student Train loss: 67.7493, Test loss: 2.3089,Attention Loss: 136.5151, Accuracy: 0.4247
	Saving checkpoint at epoch 45




attention loss for one batch  tensor(141.4107, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.2404e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 46, Student Train loss: 67.6439, Test loss: 2.2887,Attention Loss: 141.4107, Accuracy: 0.4321




attention loss for one batch  tensor(130.7925, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.9953e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 47, Student Train loss: 67.7006, Test loss: 2.3103,Attention Loss: 130.7925, Accuracy: 0.4257




attention loss for one batch  tensor(141.0922, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.0492e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 48, Student Train loss: 67.6971, Test loss: 2.3047,Attention Loss: 141.0922, Accuracy: 0.4295




attention loss for one batch  tensor(126.4341, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.7846e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 49, Student Train loss: 67.6690, Test loss: 2.2881,Attention Loss: 126.4341, Accuracy: 0.4294




attention loss for one batch  tensor(131.6034, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.0670e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 50, Student Train loss: 67.6673, Test loss: 2.3115,Attention Loss: 131.6034, Accuracy: 0.4264
	Saving checkpoint at epoch 50




attention loss for one batch  tensor(129.7968, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.7076e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 51, Student Train loss: 67.6059, Test loss: 2.3128,Attention Loss: 129.7968, Accuracy: 0.4242




attention loss for one batch  tensor(133.3567, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.8855e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 52, Student Train loss: 67.6783, Test loss: 2.3246,Attention Loss: 133.3567, Accuracy: 0.4277




attention loss for one batch  tensor(132.3594, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.5796e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 53, Student Train loss: 67.5914, Test loss: 2.3277,Attention Loss: 132.3594, Accuracy: 0.4284




attention loss for one batch  tensor(132.0371, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.2334e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 54, Student Train loss: 67.7086, Test loss: 2.3397,Attention Loss: 132.0371, Accuracy: 0.4275




attention loss for one batch  tensor(132.1279, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.5437e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 55, Student Train loss: 67.6472, Test loss: 2.3073,Attention Loss: 132.1279, Accuracy: 0.4363
	Saving checkpoint at epoch 55




attention loss for one batch  tensor(142.0049, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.7028e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 56, Student Train loss: 67.6584, Test loss: 2.3241,Attention Loss: 142.0049, Accuracy: 0.4329




attention loss for one batch  tensor(127.5218, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.7992e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 57, Student Train loss: 67.6016, Test loss: 2.3481,Attention Loss: 127.5218, Accuracy: 0.4292




attention loss for one batch  tensor(128.0128, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.6028e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 58, Student Train loss: 67.6319, Test loss: 2.3437,Attention Loss: 128.0128, Accuracy: 0.4315




attention loss for one batch  tensor(133.6204, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.0786e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 59, Student Train loss: 67.6519, Test loss: 2.3236,Attention Loss: 133.6204, Accuracy: 0.4287




attention loss for one batch  tensor(131.6717, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.3888e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 60, Student Train loss: 67.5937, Test loss: 2.3701,Attention Loss: 131.6717, Accuracy: 0.4274
	Saving checkpoint at epoch 60




attention loss for one batch  tensor(131.4660, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.3122e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 61, Student Train loss: 67.5587, Test loss: 2.3343,Attention Loss: 131.4660, Accuracy: 0.4336




attention loss for one batch  tensor(139.3632, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.3750e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 62, Student Train loss: 67.6385, Test loss: 2.3747,Attention Loss: 139.3632, Accuracy: 0.4287




attention loss for one batch  tensor(136.9496, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5909e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 63, Student Train loss: 67.5840, Test loss: 2.3573,Attention Loss: 136.9496, Accuracy: 0.4322




attention loss for one batch  tensor(126.6603, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.4625e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 64, Student Train loss: 67.5361, Test loss: 2.3779,Attention Loss: 126.6603, Accuracy: 0.4297




attention loss for one batch  tensor(133.0187, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.0580e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 65, Student Train loss: 67.5486, Test loss: 2.3685,Attention Loss: 133.0187, Accuracy: 0.4294
	Saving checkpoint at epoch 65




attention loss for one batch  tensor(131.4362, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.7630e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 66, Student Train loss: 67.6083, Test loss: 2.4081,Attention Loss: 131.4362, Accuracy: 0.4263




attention loss for one batch  tensor(133.8648, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.2211e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 67, Student Train loss: 67.5988, Test loss: 2.4411,Attention Loss: 133.8648, Accuracy: 0.4251




attention loss for one batch  tensor(129.8342, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.0572e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 68, Student Train loss: 67.5573, Test loss: 2.4020,Attention Loss: 129.8342, Accuracy: 0.4299




attention loss for one batch  tensor(128.2063, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.7430e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 69, Student Train loss: 67.5420, Test loss: 2.4474,Attention Loss: 128.2063, Accuracy: 0.4264




attention loss for one batch  tensor(140.5090, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.5020e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 70, Student Train loss: 67.5899, Test loss: 2.4125,Attention Loss: 140.5090, Accuracy: 0.4303
	Saving checkpoint at epoch 70




attention loss for one batch  tensor(136.0667, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.7743e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 71, Student Train loss: 67.5644, Test loss: 2.4094,Attention Loss: 136.0667, Accuracy: 0.4333




attention loss for one batch  tensor(128.9294, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5897e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 72, Student Train loss: 67.5731, Test loss: 2.4043,Attention Loss: 128.9294, Accuracy: 0.4292




attention loss for one batch  tensor(134.8619, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.0519e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 73, Student Train loss: 67.5571, Test loss: 2.3958,Attention Loss: 134.8619, Accuracy: 0.4341




attention loss for one batch  tensor(133.6550, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.2894e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 74, Student Train loss: 67.5385, Test loss: 2.3935,Attention Loss: 133.6550, Accuracy: 0.4398




attention loss for one batch  tensor(133.6062, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.1047e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 75, Student Train loss: 67.5222, Test loss: 2.4135,Attention Loss: 133.6062, Accuracy: 0.4334
	Saving checkpoint at epoch 75




attention loss for one batch  tensor(133.4333, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.6020e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 76, Student Train loss: 67.5178, Test loss: 2.4100,Attention Loss: 133.4333, Accuracy: 0.4357




attention loss for one batch  tensor(123.6429, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.0390e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 77, Student Train loss: 67.5644, Test loss: 2.4639,Attention Loss: 123.6429, Accuracy: 0.4262




attention loss for one batch  tensor(135.7704, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.7906e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 78, Student Train loss: 67.5451, Test loss: 2.4464,Attention Loss: 135.7704, Accuracy: 0.4337




attention loss for one batch  tensor(133.1336, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.6592e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 79, Student Train loss: 67.4977, Test loss: 2.4459,Attention Loss: 133.1336, Accuracy: 0.4337




attention loss for one batch  tensor(130.8564, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.1391e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 80, Student Train loss: 67.4573, Test loss: 2.4735,Attention Loss: 130.8564, Accuracy: 0.4235
	Saving checkpoint at epoch 80




attention loss for one batch  tensor(134.8419, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(4.6915e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 81, Student Train loss: 67.4373, Test loss: 2.4485,Attention Loss: 134.8419, Accuracy: 0.4314




attention loss for one batch  tensor(133.4696, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.2863e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 82, Student Train loss: 67.4712, Test loss: 2.4676,Attention Loss: 133.4696, Accuracy: 0.4291




attention loss for one batch  tensor(136.1960, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.2518e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 83, Student Train loss: 67.5484, Test loss: 2.4741,Attention Loss: 136.1960, Accuracy: 0.4344




attention loss for one batch  tensor(139.4144, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.8602e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 84, Student Train loss: 67.4702, Test loss: 2.4567,Attention Loss: 139.4144, Accuracy: 0.4322




attention loss for one batch  tensor(138.2286, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.5663e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 85, Student Train loss: 67.4967, Test loss: 2.4861,Attention Loss: 138.2286, Accuracy: 0.4379
	Saving checkpoint at epoch 85




attention loss for one batch  tensor(130.5190, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.7890e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 86, Student Train loss: 67.4935, Test loss: 2.4650,Attention Loss: 130.5190, Accuracy: 0.4355




attention loss for one batch  tensor(134.6438, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.8186e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 87, Student Train loss: 67.5256, Test loss: 2.4985,Attention Loss: 134.6438, Accuracy: 0.4301




attention loss for one batch  tensor(126.7763, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.2061e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 88, Student Train loss: 67.5403, Test loss: 2.4932,Attention Loss: 126.7763, Accuracy: 0.4297




attention loss for one batch  tensor(125.4995, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.8007e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 89, Student Train loss: 67.4536, Test loss: 2.5109,Attention Loss: 125.4995, Accuracy: 0.4278




attention loss for one batch  tensor(132.2080, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(1.9735e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 90, Student Train loss: 67.5062, Test loss: 2.5241,Attention Loss: 132.2080, Accuracy: 0.4289
	Saving checkpoint at epoch 90




attention loss for one batch  tensor(133.7233, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.0286e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 91, Student Train loss: 67.5226, Test loss: 2.5504,Attention Loss: 133.7233, Accuracy: 0.4214




attention loss for one batch  tensor(135.6661, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5567e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 92, Student Train loss: 67.4847, Test loss: 2.4925,Attention Loss: 135.6661, Accuracy: 0.4253




attention loss for one batch  tensor(131.8742, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.4450e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 93, Student Train loss: 67.4807, Test loss: 2.5031,Attention Loss: 131.8742, Accuracy: 0.4326




attention loss for one batch  tensor(139.1153, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(1.7026e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 94, Student Train loss: 67.5021, Test loss: 2.5541,Attention Loss: 139.1153, Accuracy: 0.4291




attention loss for one batch  tensor(130.9589, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.9254e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 95, Student Train loss: 67.4971, Test loss: 2.5404,Attention Loss: 130.9589, Accuracy: 0.4303
	Saving checkpoint at epoch 95




attention loss for one batch  tensor(130.3418, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.5873e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 96, Student Train loss: 67.3943, Test loss: 2.5195,Attention Loss: 130.3418, Accuracy: 0.4318




attention loss for one batch  tensor(135.8808, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.9209e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 97, Student Train loss: 67.4412, Test loss: 2.5722,Attention Loss: 135.8808, Accuracy: 0.4285




attention loss for one batch  tensor(136.8189, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.1918e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 98, Student Train loss: 67.4354, Test loss: 2.5788,Attention Loss: 136.8189, Accuracy: 0.4289




attention loss for one batch  tensor(136.9891, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(3.7113e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 99, Student Train loss: 67.5555, Test loss: 2.5298,Attention Loss: 136.9891, Accuracy: 0.4343




attention loss for one batch  tensor(135.2263, device='cuda:1', grad_fn=<DivBackward0>)
classification loss for one batch  tensor(2.4673e-05, device='cuda:1', grad_fn=<DivBackward0>)
Epoch: 100, Student Train loss: 67.4934, Test loss: 2.5801,Attention Loss: 135.2263, Accuracy: 0.4322
	Saving checkpoint at epoch 100
student_final and experiment details saved under ViT_Student_teacher
