In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, SwinForImageClassification
from transformers import AutoFeatureExtractor
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torch

import os
import pandas as pd
from PIL import Image
import numpy as np

import glob

from matplotlib import pyplot as plt

In [2]:
class Args():
    batch_size = 16
    epochs = 50
    patience = 5
    device = 'cuda:1'
    temperature = 5
    lambda_param = 0.5

args = Args

In [3]:
class CustomImageDataset(Dataset):
    def __init__(self, transform=None):
        """
        Args:
            annotations_file (string): Path to the csv file with annotations.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        img_dir = '/hdd/yuchen/satdata/l4s_classification/'
        img_lst = []
        label_lst = []
        
        files = glob.glob(img_dir+'n0/*')
        for img_idx in files:
            img_lst.append(Image.open(img_idx).convert("RGB"))
            label_lst.append(0)
        print(len(files))
        files = glob.glob(img_dir+'n1/*')
        print(len(files))
        
        for img_idx in files:
            img_lst.append(Image.open(img_idx).convert("RGB"))
            label_lst.append(1)
            
        self.img_labels = label_lst
        self.img = img_lst
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = self.img[idx]
        label = self.img_labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the input size expected by your model
    transforms.ToTensor(),  # Convert images to Tensor
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize images
])


In [5]:
# Initialize the dataset
custom_dataset = CustomImageDataset(transform=transform)

# Create a DataLoader
batch_size = args.batch_size
shuffle = True

# Split sizes
total_size = len(custom_dataset)
train_size = int(total_size * 0.8)
val_size = total_size - train_size

# Split the dataset
torch.manual_seed(42)  # For reproducibility
train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

# DataLoaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

1567
2231


In [6]:
ct = 0
for batch in valid_loader:
    ct += 1
    if ct == 2: break

print(batch[0].shape, batch[1].shape)

torch.Size([16, 3, 224, 224]) torch.Size([16])


In [7]:
# fig, axs = plt.subplots(2, 3, figsize=(15, 10)) # Creating a 2x3 subplot

# for i in range(6):
#     img = batch[0][i].permute(1,2,0).detach().cpu().numpy()
#     axs[i//3, i%3].imshow(img)
#     axs[i//3, i%3].set_xticks([])
#     axs[i//3, i%3].set_yticks([])

# plt.tight_layout()


In [8]:
# fig.savefig('/hdd/yuchen/landslidesimg.pdf')

In [9]:
teacher_model = ViTForImageClassification.from_pretrained("google/vit-large-patch16-224")
teacher_model.classifier =  nn.Linear(in_features=1024, out_features=2, bias=True)

In [10]:
ckpt = torch.load('/hdd/yuchen/satdata/weights/fmow_pretrain.pth')['model']

teacher_model.vit.embeddings.patch_embeddings.projection.weight = torch.nn.Parameter(ckpt['patch_embed.proj.weight'])
teacher_model.vit.embeddings.patch_embeddings.projection.bias = torch.nn.Parameter(ckpt['patch_embed.proj.bias'])

for layer_num in range(24):
    teacher_model.vit.encoder.layer[layer_num].attention.attention.query.weight = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.attn.qkv.weight'][0:1024])
    teacher_model.vit.encoder.layer[layer_num].attention.attention.query.bias = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.attn.qkv.bias'][0:1024])
    teacher_model.vit.encoder.layer[layer_num].attention.attention.key.weight = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.attn.qkv.weight'][1024:2048])
    teacher_model.vit.encoder.layer[layer_num].attention.attention.key.bias = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.attn.qkv.bias'][1024:2048])
    teacher_model.vit.encoder.layer[layer_num].attention.attention.value.weight = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.attn.qkv.weight'][2048:])
    teacher_model.vit.encoder.layer[layer_num].attention.attention.value.bias = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.attn.qkv.bias'][2048:])
    teacher_model.vit.encoder.layer[layer_num].intermediate.dense.weight = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.mlp.fc1.weight'])
    teacher_model.vit.encoder.layer[layer_num].intermediate.dense.bias = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.mlp.fc1.bias'])
    teacher_model.vit.encoder.layer[layer_num].output.dense.weight = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.mlp.fc2.weight'])
    teacher_model.vit.encoder.layer[layer_num].output.dense.bias = torch.nn.Parameter(ckpt[f'blocks.{layer_num}.mlp.fc2.bias'])

del ckpt

## train teacher model

In [10]:
model_path = "/hdd/yuchen/satdata/weights/teacher_model_weights_finetuneweights.pth"

device = args.device
teacher_model.to(device)
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-4)

best_accur = 0
ct_patience = 0

for epoch in range(args.epochs):
    
    teacher_model.train()
    train_loss = 0
    ct = 0
    for images, labels in train_loader:
        ct +=1
        print(ct , '/', len(train_loader), end='\r')
        images, labels = images.to(device), labels.to(device)
    
        teacher_output = teacher_model(images)
        hard_loss = nn.functional.cross_entropy(teacher_output.logits, labels)
        train_loss += hard_loss
    
        optimizer.zero_grad()
        hard_loss.backward()
        optimizer.step()
        
    print(f'Epoch [{epoch}] train loss:', round(train_loss.item()/len(train_loader), 5), end = ' ')

    teacher_model.eval()
    acc_lst = []
    for images, labels in valid_loader:
        images, labels = images.to(device), labels.to(device)
        teacher_output = teacher_model(images)
    
        pred = torch.max(teacher_output.logits,axis=1).indices.cpu().numpy()
        label = labels.detach().cpu().numpy()
        acc_lst.append(len(np.where(pred==label)[0])/len(label))
        
    curr_accur = np.mean(acc_lst)
    if curr_accur>best_accur:
        ct_patience = 0
        best_accur = curr_accur
        torch.save(teacher_model.state_dict(), model_path)
    else:
        ct_patience += 1
        if ct_patience > args.patience:
            print('end')
            break
            
    print('valid accuracy:', round(curr_accur, 5))

Epoch [0] train loss: 0.67726 valid accuracy: 0.64761
Epoch [1] train loss: 0.53478 valid accuracy: 0.79122
Epoch [2] train loss: 0.48343 valid accuracy: 0.73138
Epoch [3] train loss: 0.44888 valid accuracy: 0.78856
Epoch [4] train loss: 0.40197 valid accuracy: 0.78457
Epoch [5] train loss: 0.38802 valid accuracy: 0.80319
Epoch [6] train loss: 0.36773 valid accuracy: 0.81516
Epoch [7] train loss: 0.33496 valid accuracy: 0.81649
Epoch [8] train loss: 0.3409 valid accuracy: 0.81782
Epoch [9] train loss: 0.29245 valid accuracy: 0.80718
Epoch [10] train loss: 0.25785 valid accuracy: 0.82713
Epoch [11] train loss: 0.22277 valid accuracy: 0.82846
Epoch [12] train loss: 0.19944 valid accuracy: 0.83378
Epoch [13] train loss: 0.2243 valid accuracy: 0.83378
Epoch [14] train loss: 0.18339 valid accuracy: 0.75266
Epoch [15] train loss: 0.1587 valid accuracy: 0.80585
Epoch [16] train loss: 0.12743 valid accuracy: 0.82181
Epoch [17] train loss: 0.08209 valid accuracy: 0.79654
Epoch [18] train loss: 

## train student model 

In [11]:
# model_path = "/hdd/yuchen/satdata/weights/teacher_model_weights.pth"
model_path = "/hdd/yuchen/satdata/weights/teacher_model_weights_finetuneweights.pth"
teacher_model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [7]:
# student_model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
student_model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224")
student_model.classifier = nn.Linear(in_features=1024, out_features=2, bias=True)


In [11]:
# ckpt = torch.load('/hdd/yuchen/satdata/weights/swinmae100.pth')

# student_model.swin.embeddings.patch_embeddings.projection.weight = torch.nn.Parameter(ckpt['patch_embed.proj.weight'])
# student_model.swin.embeddings.patch_embeddings.projection.bias = torch.nn.Parameter(ckpt['patch_embed.proj.bias'])
# student_model.swin.embeddings.norm.weight = torch.nn.Parameter(ckpt['patch_embed.norm.weight'])
# student_model.swin.embeddings.norm.bias = torch.nn.Parameter(ckpt['patch_embed.norm.bias'])

# for idx in [0,1,2,3]:
#     student_model.swin.encoder.layers[idx].blocks[0].layernorm_before.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.norm1.weight'])
#     student_model.swin.encoder.layers[idx].blocks[0].layernorm_before.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.norm1.bias'])

#     curr_size = student_model.swin.encoder.layers[idx].blocks[0].attention.self.query.weight.shape[0]
    
#     student_model.swin.encoder.layers[idx].blocks[0].attention.self.query.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.qkv.weight'][:curr_size])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.self.key.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.qkv.weight'][curr_size:curr_size*2])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.self.value.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.qkv.weight'][curr_size*2:])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.self.query.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.qkv.bias'][:curr_size])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.self.key.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.qkv.bias'][curr_size:curr_size*2])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.self.value.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.qkv.bias'][curr_size*2:])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.output.dense.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.proj.weight'])
#     student_model.swin.encoder.layers[idx].blocks[0].attention.output.dense.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.attn.proj.bias'])
    
#     student_model.swin.encoder.layers[idx].blocks[0].layernorm_after.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.norm2.weight'])
#     student_model.swin.encoder.layers[idx].blocks[0].layernorm_after.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.norm2.bias'])
    
#     student_model.swin.encoder.layers[idx].blocks[0].intermediate.dense.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.mlp.fc1.weight'])
#     student_model.swin.encoder.layers[idx].blocks[0].intermediate.dense.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.mlp.fc1.bias'])
    
#     student_model.swin.encoder.layers[idx].blocks[0].output.dense.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.mlp.fc2.weight'])
#     student_model.swin.encoder.layers[idx].blocks[0].output.dense.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.0.mlp.fc2.bias'])
    
    
#     student_model.swin.encoder.layers[idx].blocks[1].layernorm_before.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.norm1.weight'])
#     student_model.swin.encoder.layers[idx].blocks[1].layernorm_before.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.norm1.bias'])
    
#     student_model.swin.encoder.layers[idx].blocks[1].attention.self.query.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.qkv.weight'][:curr_size])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.self.key.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.qkv.weight'][curr_size:curr_size*2])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.self.value.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.qkv.weight'][curr_size*2:])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.self.query.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.qkv.bias'][:curr_size])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.self.key.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.qkv.bias'][curr_size:curr_size*2])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.self.value.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.qkv.bias'][curr_size*2:])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.output.dense.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.proj.weight'])
#     student_model.swin.encoder.layers[idx].blocks[1].attention.output.dense.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.attn.proj.bias'])
    
#     student_model.swin.encoder.layers[idx].blocks[1].layernorm_after.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.norm2.weight'])
#     student_model.swin.encoder.layers[idx].blocks[1].layernorm_after.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.norm2.bias'])
#     student_model.swin.encoder.layers[idx].blocks[1].intermediate.dense.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.mlp.fc1.weight'])
#     student_model.swin.encoder.layers[idx].blocks[1].intermediate.dense.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.mlp.fc1.bias'])
#     student_model.swin.encoder.layers[idx].blocks[1].output.dense.weight = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.mlp.fc2.weight'])
#     student_model.swin.encoder.layers[idx].blocks[1].output.dense.bias = torch.nn.Parameter(ckpt[f'layers.{idx}.blocks.1.mlp.fc2.bias'])
    
#     if idx != 3:
#         student_model.swin.encoder.layers[idx].downsample.reduction.weight =  torch.nn.Parameter(ckpt[f'layers.{idx}.downsample.reduction.weight'])
#         student_model.swin.encoder.layers[idx].downsample.norm.weight =  torch.nn.Parameter(ckpt[f'layers.{idx}.downsample.norm.weight'])
#         student_model.swin.encoder.layers[idx].downsample.norm.bias =  torch.nn.Parameter(ckpt[f'layers.{idx}.downsample.norm.bias'])

In [13]:
args.device

'cuda:1'

In [14]:
class DistillationLoss(nn.Module):
    def __init__(self, temperature=3, lambda_param=0.5):
        super(DistillationLoss, self).__init__()
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.lambda_param = lambda_param

    def forward(self, student_logits, teacher_logits, labels):
        soft_loss = self.kl_div(
            nn.functional.log_softmax(student_logits / self.temperature, dim=1),
            nn.functional.softmax(teacher_logits / self.temperature, dim=1),
        ) * (self.temperature ** 2)
        
        hard_loss = nn.functional.cross_entropy(student_logits, labels)
        
        loss = (1. - self.lambda_param) * hard_loss + self.lambda_param * soft_loss
        return  loss

In [15]:
student_model_path = "/hdd/yuchen/satdata/weights/student_model_weights_pretrainswin.pth"
args.lambda_param = 0.9

# Training settings
device = torch.device(args.device)
student_model.to(device)
teacher_model.to(device)
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
distillation_loss = DistillationLoss(temperature = args.temperature, lambda_param = args.lambda_param)

best_dist_loss = np.inf
patience = args.patience
ct_patience = 0

for epoch in range(args.epochs):
    student_model.train()
    train_loss = 0
    ct = 0
    for images, labels in train_loader:
        ct +=1
        print(ct , '/', len(train_loader), end='\r')
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            teacher_output = teacher_model(images).logits
                        
        student_output = student_model(images).logits
        loss = distillation_loss(student_output, teacher_output, labels)
        train_loss += loss
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f'Epoch [{epoch}] train loss:', round(train_loss.item()/len(train_loader), 5), end = ' ')

    
    teacher_model.eval()
    student_model.eval()
    
    acc_lst = []
    val_loss = 0
    for images, labels in valid_loader:
        images, labels = images.to(device), labels.to(device)
        student_output = student_model(images)

        with torch.no_grad():
            teacher_output = teacher_model(images)
    
        pred = torch.max(student_output.logits,axis=1).indices.cpu().numpy()
        label = labels.detach().cpu().numpy()
        acc_lst.append(len(np.where(pred==label)[0])/len(label))
        val_loss += distillation_loss(student_output.logits, teacher_output.logits, labels).item()
        
    curr_accur = np.mean(acc_lst)
    
    if val_loss < best_dist_loss:
        ct_patience = 0
        best_dist_loss = val_loss
        student_model_path = f"/hdd/yuchen/satdata/weights/temp/student_model_weights_{epoch}.pth"
        
        torch.save(student_model.state_dict(), student_model_path)
    else:
        ct_patience+=1
        if ct_patience > args.patience:
            print('end')
            break
    print('valid accuracy:', round(curr_accur, 5), 'dist loss:',  round(val_loss/len(valid_loader), 5))

Epoch [0] train loss: 0.96097 valid accuracy: 0.86436 dist loss: 0.54203
Epoch [1] train loss: 0.56693 valid accuracy: 0.87633 dist loss: 0.50666
Epoch [2] train loss: 0.37416 valid accuracy: 0.88165 dist loss: 0.58333
Epoch [3] train loss: 0.26402 valid accuracy: 0.87234 dist loss: 0.53506
Epoch [4] train loss: 0.19509 valid accuracy: 0.86569 dist loss: 0.5148
Epoch [5] train loss: 0.14609 valid accuracy: 0.86968 dist loss: 0.46218
Epoch [6] train loss: 0.10937 valid accuracy: 0.875 dist loss: 0.47058
Epoch [7] train loss: 0.09264 valid accuracy: 0.88165 dist loss: 0.44099
Epoch [8] train loss: 0.06385 valid accuracy: 0.87899 dist loss: 0.44042
Epoch [9] train loss: 0.05953 valid accuracy: 0.87234 dist loss: 0.44452
Epoch [10] train loss: 0.05534 valid accuracy: 0.88298 dist loss: 0.43916
Epoch [11] train loss: 0.05059 valid accuracy: 0.85771 dist loss: 0.45308
Epoch [12] train loss: 0.05579 valid accuracy: 0.88165 dist loss: 0.42758
Epoch [13] train loss: 0.05222 valid accuracy: 0.87

In [19]:
distillation_loss = DistillationLoss()

distillation_loss(student_output, teacher_output, labels)


tensor(0.8634, device='cuda:1', grad_fn=<AddBackward0>)

In [11]:
for batch in train_loader:
    break

batch[0].shape

torch.Size([32, 3, 224, 224])

In [13]:
batch[1]

tensor([3, 0, 0, 7, 7, 8, 4, 8, 7, 0, 5, 1, 6, 3, 1, 9, 0, 5, 7, 4, 5, 6, 2, 0,
        3, 4, 6, 3, 4, 5, 1, 9])

In [7]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        with torch.no_grad():
          teacher_output = self.teacher(**inputs)

        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        student_target_loss = student_output.loss

        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

In [None]:
from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification

training_args = TrainingArguments(
    output_dir="my-awesome-model",
    num_train_epochs=30,
    fp16=True,
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    )

num_labels = len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)

In [None]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

In [None]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    training_args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    data_collator=data_collator,
    tokenizer=teacher_processor,
    compute_metrics=compute_metrics,
    temperature=5,
    lambda_param=0.5
)