# Teacher Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
from torch.autograd import grad
from sklearn.model_selection import train_test_split

In [None]:
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
# Interaction array (with BERT embeddings) as float32 and the cluster labels
# as integers.
interaction_array = np.load('normalized_interaction_array.npy').astype(np.float32)
cluster_labels = np.load('user_cluster_labels_with_embeddings.npy').astype(int)

# Train/test arrays, loaded from disk in the order they are named and wrapped
# as tensors (dtype is whatever the .npy files contain).
trainX, trainY, testX, testY = (
    torch.tensor(np.load(npy_path))
    for npy_path in ('trainX.npy', 'trainY.npy', 'testX.npy', 'testY.npy')
)

# Only the training split is batched here; batches are reshuffled every epoch.
train_dataset = TensorDataset(trainX, trainY)
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
latent_dim = 200   # length of the generator's noise vector
num_classes = 26   # number of clusters
# Width of one generated sample: second dimension of the interaction array.
output_size = interaction_array.shape[1]

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# 1. Generator
class Generator(nn.Module):
    """Teacher generator of the conditional GAN.

    A class label is mapped to a 128-dimensional embedding, concatenated with
    the noise vector, and pushed through an MLP (2048 -> 1024 -> output_size).
    The final ``Tanh`` bounds every output value to ``[-1, 1]``.

    Args:
        latent_dim: Number of features in the noise vector.
        num_classes: Number of distinct labels the embedding can look up.
        output_size: Number of features in one generated sample.
    """

    def __init__(self, latent_dim, num_classes, output_size):
        super().__init__()
        embed_dim = 128
        self.label_embedding = nn.Embedding(num_classes, embed_dim)

        # Hidden stages share the pattern Linear -> LeakyReLU -> BatchNorm.
        # Layers are created in forward order so parameter names
        # (model.0, model.2, ...) stay the same as a hand-written Sequential.
        layers = []
        in_features = latent_dim + embed_dim
        for hidden in (2048, 1024):
            layers += [
                nn.Linear(in_features, hidden),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(hidden),
            ]
            in_features = hidden
        layers += [nn.Linear(in_features, output_size), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate samples from ``noise`` conditioned on ``labels``.

        ``noise`` and the label embedding are concatenated along dim 1, so
        ``noise`` must be 2-D with ``latent_dim`` columns and ``labels`` must
        hold one integer class index per row.
        """
        conditioned = torch.cat([noise, self.label_embedding(labels)], dim=1)
        return self.model(conditioned)


# 2. Discriminator
class Discriminator(nn.Module):
    """Teacher discriminator of the conditional GAN.

    A class label is mapped to a 128-dimensional embedding, concatenated with
    the sample, and scored by an MLP (2048 -> 1024 -> 1). The final
    ``Sigmoid`` keeps the score in ``[0, 1]``.

    NOTE(review): the training loop in this file uses a Wasserstein-style
    loss with a gradient penalty; a sigmoid output and BatchNorm layers are
    unusual for that setup - confirm this combination is intended.

    Args:
        output_size: Number of features in one (real or generated) sample.
        num_classes: Number of distinct labels the embedding can look up.
    """

    def __init__(self, output_size, num_classes):
        super().__init__()
        embed_dim = 128
        self.label_embedding = nn.Embedding(num_classes, embed_dim)

        def hidden_block(n_in, n_out):
            # One hidden stage: Linear -> LeakyReLU -> BatchNorm -> Dropout.
            return [
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(n_out),
                nn.Dropout(0.4),
            ]

        self.model = nn.Sequential(
            *hidden_block(output_size + embed_dim, 2048),
            *hidden_block(2048, 1024),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, data, labels):
        """Score ``data`` conditioned on ``labels``; returns one value per row.

        ``data`` and the label embedding are concatenated along dim 1, so
        ``data`` must be 2-D with ``output_size`` columns and ``labels`` must
        hold one integer class index per row.
        """
        conditioned = torch.cat([data, self.label_embedding(labels)], dim=1)
        return self.model(conditioned)


# Build the teacher networks (generator first, then discriminator) and move
# each one to the selected device. Module.to() moves the module in place.
generator = Generator(
    latent_dim=latent_dim,
    num_classes=num_classes,
    output_size=output_size,
)
generator.to(device)

discriminator = Discriminator(output_size=output_size, num_classes=num_classes)
discriminator.to(device)

# Persist the freshly constructed (still untrained) modules as whole pickled
# objects, not just their state_dicts.
torch.save(generator, 'teacher_generator_model.pth')
torch.save(discriminator, 'teacher_discriminator_model.pth')


In [None]:
# Train CGAN with Balanced Training Strategy
def train_cgan(generator, discriminator, train_loader, latent_dim, num_classes, epochs=750, checkpoint_path='teacher_cgan_checkpoint_new.pth'):
    """Train the teacher conditional GAN with a Wasserstein-style loss.

    For every batch the discriminator is updated once (on epochs where
    ``epoch % k_d == 0``) and the generator is updated ``k_g`` times.

    * Discriminator loss: ``-mean(D(real)) + mean(D(fake)) + 5 * gp`` where
      ``gp`` is the gradient penalty.
    * Generator loss: ``-mean(D(fake))``.

    The function relies on names it does not define itself; they must exist
    at module scope: ``optimizer_g``, ``optimizer_d``, ``device``,
    ``load_checkpoint`` and ``compute_gradient_penalty`` (called with the
    same argument order as in ``train_student_cgan``).

    Args:
        generator: Conditional generator, called as ``generator(noise, labels)``.
        discriminator: Conditional discriminator, called as
            ``discriminator(data, labels)``.
        train_loader: Iterable yielding ``(real_data, real_labels)`` batches.
        latent_dim: Number of features in the noise vector.
        num_classes: Fake labels are drawn uniformly from ``[0, num_classes)``.
        epochs: Exclusive upper bound of the epoch index.
        checkpoint_path: Forwarded to ``load_checkpoint`` to resume training.

    Returns:
        None. The models and the module-level optimizers are updated in place.
    """
    # Update schedule.
    k_d = 1  # discriminator is trained on epochs where epoch % k_d == 0
    k_g = 2  # generator updates per batch

    # Resume from a checkpoint if load_checkpoint finds one. The second
    # return value (best generator loss so far) is not used in this loop.
    start_epoch, _best_g_loss = load_checkpoint(generator, discriminator, optimizer_g, optimizer_d, checkpoint_path)

    for epoch in range(start_epoch + 1, epochs):
        generator.train()
        discriminator.train()

        for real_data, real_labels in train_loader:
            batch_size = real_data.size(0)
            real_data, real_labels = real_data.to(device), real_labels.to(device)

            # ----------------------
            # Train Discriminator
            # ----------------------
            if epoch % k_d == 0:
                optimizer_d.zero_grad()

                # Real data: the critic should score these high.
                d_real = discriminator(real_data, real_labels)
                d_loss_real = -torch.mean(d_real)

                # Fake data: detached so this step does not touch the generator.
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
                fake_data = generator(noise, fake_labels).detach()
                d_fake = discriminator(fake_data, fake_labels)
                d_loss_fake = torch.mean(d_fake)

                # Gradient penalty between the real and the generated batch.
                gp = compute_gradient_penalty(discriminator, real_data, fake_data, real_labels)

                # Total discriminator loss (gradient penalty weight is 5).
                d_loss = d_loss_real + d_loss_fake + 5 * gp
                d_loss.backward()
                optimizer_d.step()

            # ----------------------
            # Train Generator
            # ----------------------
            for _ in range(k_g):
                optimizer_g.zero_grad()

                # Fresh noise and labels for every generator update.
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
                fake_data = generator(noise, fake_labels)

                # Generator tries to raise the discriminator's score on fakes.
                g_loss = -torch.mean(discriminator(fake_data, fake_labels))
                g_loss.backward()
                optimizer_g.step()


# Run teacher training; checkpoint_path is left at the function's default.
train_cgan(
    generator=generator,
    discriminator=discriminator,
    train_loader=train_loader,
    latent_dim=latent_dim,
    num_classes=num_classes,
    epochs=750,
)

Epoch 3/750, D Loss: 0.2391, G Loss: -0.9373, D Accuracy: 54.20%, G Accuracy: 93.70%
Best discriminator saved at epoch 3 with accuracy: 54.20%
Best generator saved at epoch 3 with accuracy: 93.70%
Checkpoint saved at epoch 3
Epoch 4/750, D Loss: -0.0673, G Loss: -0.9527, D Accuracy: 53.75%, G Accuracy: 90.60%
Checkpoint saved at epoch 4
Epoch 5/750, D Loss: -0.0859, G Loss: -0.9526, D Accuracy: 53.80%, G Accuracy: 92.10%
Checkpoint saved at epoch 5
Epoch 6/750, D Loss: 0.0579, G Loss: -0.9062, D Accuracy: 53.65%, G Accuracy: 92.20%
Checkpoint saved at epoch 6
Epoch 7/750, D Loss: -0.1376, G Loss: -0.8906, D Accuracy: 54.45%, G Accuracy: 92.00%
Best discriminator saved at epoch 7 with accuracy: 54.45%
Checkpoint saved at epoch 7
Epoch 8/750, D Loss: -0.2236, G Loss: -0.9062, D Accuracy: 54.20%, G Accuracy: 91.50%
Checkpoint saved at epoch 8
Epoch 9/750, D Loss: -0.2338, G Loss: -0.9372, D Accuracy: 53.75%, G Accuracy: 93.20%
Checkpoint saved at epoch 9
Epoch 10/750, D Loss: 0.0011, G Lo

# Student Model

In [None]:
class StudentGenerator(nn.Module):
    """Smaller (student) generator of the conditional GAN.

    A class label is mapped to a 60-dimensional embedding, concatenated with
    the noise vector, and pushed through an MLP (512 -> 256 -> output_size).
    The final ``Tanh`` bounds every output value to ``[-1, 1]``.

    Args:
        latent_dim: Number of features in the noise vector.
        num_classes: Number of distinct labels the embedding can look up.
        output_size: Number of features in one generated sample.
    """

    def __init__(self, latent_dim, num_classes, output_size):
        super().__init__()
        label_dim = 60
        self.label_embedding = nn.Embedding(num_classes, label_dim)

        # Consecutive entries of `widths` are the (in, out) sizes of the
        # hidden stages; each stage is Linear -> LeakyReLU -> BatchNorm.
        widths = (latent_dim + label_dim, 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend(
                (nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.BatchNorm1d(fan_out))
            )
        stack.append(nn.Linear(widths[-1], output_size))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, noise, labels):
        """Generate samples from ``noise`` conditioned on ``labels``.

        ``noise`` and the label embedding are concatenated along dim 1, so
        ``noise`` must be 2-D with ``latent_dim`` columns and ``labels`` must
        hold one integer class index per row.
        """
        embedded = self.label_embedding(labels)
        return self.model(torch.cat([noise, embedded], dim=1))


class StudentDiscriminator(nn.Module):
    """Smaller (student) discriminator of the conditional GAN.

    A class label is mapped to a 60-dimensional embedding, concatenated with
    the sample, and scored by an MLP (512 -> 256 -> 1). The final ``Sigmoid``
    keeps the score in ``[0, 1]``.

    NOTE(review): the student training loop in this file uses a
    Wasserstein-style loss with a gradient penalty; a sigmoid output and
    BatchNorm layers are unusual for that setup - confirm this is intended.

    Args:
        output_size: Number of features in one (real or generated) sample.
        num_classes: Number of distinct labels the embedding can look up.
    """

    def __init__(self, output_size, num_classes):
        super().__init__()
        label_dim = 60
        self.label_embedding = nn.Embedding(num_classes, label_dim)

        # The linear layers are created in forward order, matching the order
        # in which they appear in the Sequential below.
        first = nn.Linear(output_size + label_dim, 512)
        second = nn.Linear(512, 256)
        head = nn.Linear(256, 1)

        self.model = nn.Sequential(
            # Stage 1
            first,
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            # Stage 2
            second,
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            # Score
            head,
            nn.Sigmoid(),
        )

    def forward(self, data, labels):
        """Score ``data`` conditioned on ``labels``; returns one value per row.

        ``data`` and the label embedding are concatenated along dim 1, so
        ``data`` must be 2-D with ``output_size`` columns and ``labels`` must
        hold one integer class index per row.
        """
        embedded = self.label_embedding(labels)
        return self.model(torch.cat([data, embedded], dim=1))


# Build the student networks (generator first, then discriminator) and move
# each one to the selected device. Module.to() moves the module in place.
student_generator = StudentGenerator(
    latent_dim=latent_dim,
    num_classes=num_classes,
    output_size=output_size,
)
student_generator.to(device)

student_discriminator = StudentDiscriminator(
    output_size=output_size,
    num_classes=num_classes,
)
student_discriminator.to(device)

# Persist the freshly constructed (still untrained) modules as whole pickled
# objects, not just their state_dicts.
torch.save(student_generator, 'student_generator_model.pth')
torch.save(student_discriminator, 'student_discriminator_model.pth')


In [None]:
def train_student_cgan(
    teacher_generator,
    teacher_discriminator,
    student_generator,
    student_discriminator,
    train_loader,
    latent_dim,
    num_classes,
    epochs=500,
    checkpoint_path="student_cgan_checkpoint.pth",
    distill_weight=1.0,
):
    """Train the student conditional GAN with adversarial + distillation loss.

    For every batch the student discriminator is updated once (on epochs
    where ``epoch % k_d == 0``) and the student generator ``k_g`` times.

    * Discriminator loss: ``-mean(D(real)) + mean(D(fake)) + 10 * gp`` where
      ``gp`` comes from ``compute_gradient_penalty``.
    * Generator loss: ``-mean(D(fake)) + distill_weight * MSE(student_out,
      teacher_out)``, where both generators receive the same noise and labels.

    The function relies on two names that must exist at module scope:
    ``load_student_checkpoint`` and ``compute_gradient_penalty``.

    Args:
        teacher_generator: Trained generator used as the distillation target.
            It is switched to eval mode by this function and never updated.
        teacher_discriminator: Currently unused; kept so existing callers
            keep working.
        student_generator: Generator being trained.
        student_discriminator: Discriminator being trained.
        train_loader: Iterable yielding ``(real_data, real_labels)`` batches.
        latent_dim: Number of features in the noise vector (must match what
            both generators expect).
        num_classes: Fake labels are drawn uniformly from ``[0, num_classes)``.
        epochs: Exclusive upper bound of the epoch index.
        checkpoint_path: Forwarded to ``load_student_checkpoint``.
        distill_weight: Multiplier of the distillation term in the generator
            loss. NOTE(review): the default of 1.0 is an assumption - tune or
            confirm against the intended training recipe.

    Returns:
        None. The student models are updated in place.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Update schedule.
    k_d = 1  # discriminator is trained on epochs where epoch % k_d == 0
    k_g = 2  # generator updates per batch

    # Optimizers for the student models.
    optimizer_sg = optim.Adam(student_generator.parameters(), lr=0.00005, betas=(0.5, 0.999))
    optimizer_sd = optim.Adam(student_discriminator.parameters(), lr=0.00002, betas=(0.5, 0.999))

    # Loss function for distillation (student output vs. teacher output).
    distillation_loss = nn.MSELoss()

    # Resume from a checkpoint if load_student_checkpoint finds one. The best
    # accuracies it returns are not used inside this loop.
    start_epoch, best_d_accuracy, best_g_accuracy = load_student_checkpoint(student_generator, student_discriminator, optimizer_sg, optimizer_sd, checkpoint_path)

    # The teacher only provides targets: eval mode keeps layers such as
    # BatchNorm from updating their running statistics during its forward pass.
    teacher_generator.eval()

    for epoch in range(start_epoch + 1, epochs):
        student_generator.train()
        student_discriminator.train()

        for real_data, real_labels in train_loader:
            batch_size = real_data.size(0)
            real_data, real_labels = real_data.to(device), real_labels.to(device)

            # ----------------------
            # Train Student Discriminator
            # ----------------------
            if epoch % k_d == 0:
                optimizer_sd.zero_grad()

                # Real data
                d_real = student_discriminator(real_data, real_labels)
                d_loss_real = -torch.mean(d_real)

                # Fake data: detached so this step does not touch the generator.
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
                student_fake_data = student_generator(noise, fake_labels).detach()
                d_fake = student_discriminator(student_fake_data, fake_labels)
                d_loss_fake = torch.mean(d_fake)

                # Gradient penalty
                gp = compute_gradient_penalty(student_discriminator, real_data, student_fake_data, real_labels)

                # Total discriminator loss (gradient penalty weight is 10).
                d_loss = d_loss_real + d_loss_fake + 10 * gp
                d_loss.backward()
                optimizer_sd.step()

            # ----------------------
            # Train Student Generator
            # ----------------------
            for _ in range(k_g):
                optimizer_sg.zero_grad()

                # Fresh noise and labels for every generator update.
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
                student_fake_data = student_generator(noise, fake_labels)

                # Adversarial term: raise the discriminator's score on fakes.
                student_g_fake = student_discriminator(student_fake_data, fake_labels)
                g_loss_adv = -torch.mean(student_g_fake)

                # Distillation term: match the teacher's output for the same
                # noise and labels. no_grad keeps the teacher out of the graph.
                with torch.no_grad():
                    teacher_fake_data = teacher_generator(noise, fake_labels)
                g_loss_distill = distillation_loss(student_fake_data, teacher_fake_data)

                # Without backward() and step() the generator never learns.
                g_loss = g_loss_adv + distill_weight * g_loss_distill
                g_loss.backward()
                optimizer_sg.step()



# Run student training; checkpoint_path is left at the function's default.
# NOTE(review): teacher_generator / teacher_discriminator are not created in
# the cells shown here - presumably loaded from the trained teacher; confirm.
train_student_cgan(
    teacher_generator=teacher_generator,
    teacher_discriminator=teacher_discriminator,
    student_generator=student_generator,
    student_discriminator=student_discriminator,
    train_loader=train_loader,
    latent_dim=latent_dim,
    num_classes=num_classes,
    epochs=500,
)


No checkpoint found, starting from scratch.
Epoch 1/500, D Loss: 2.0070, G Loss: -0.4669, D Accuracy: 61.00%, G Accuracy: 76.40%
Best discriminator saved at epoch 1 with accuracy: 61.00%
Best generator saved at epoch 1 with accuracy: 76.40%
Checkpoint saved at epoch 1
Epoch 2/500, D Loss: 1.3682, G Loss: -0.5340, D Accuracy: 68.35%, G Accuracy: 71.10%
Best discriminator saved at epoch 2 with accuracy: 68.35%
Reloading best generator at epoch 2...


  student_generator.load_state_dict(torch.load("best_student_generator.pth"))


Checkpoint saved at epoch 2
Epoch 3/500, D Loss: 2.4447, G Loss: -0.4897, D Accuracy: 67.75%, G Accuracy: 65.50%
Reloading best generator at epoch 3...
Checkpoint saved at epoch 3
Epoch 4/500, D Loss: 2.6902, G Loss: -0.5748, D Accuracy: 66.45%, G Accuracy: 52.20%
Reloading best generator at epoch 4...
Checkpoint saved at epoch 4
Epoch 5/500, D Loss: 2.2835, G Loss: -0.5513, D Accuracy: 79.20%, G Accuracy: 39.90%
Best discriminator saved at epoch 5 with accuracy: 79.20%
Reloading best generator at epoch 5...
Checkpoint saved at epoch 5
Epoch 6/500, D Loss: 4.5508, G Loss: -0.3765, D Accuracy: 86.80%, G Accuracy: 31.10%
Best discriminator saved at epoch 6 with accuracy: 86.80%
Reloading best generator at epoch 6...
Checkpoint saved at epoch 6
Epoch 7/500, D Loss: 2.8791, G Loss: -0.4757, D Accuracy: 70.60%, G Accuracy: 44.30%
Reloading best discriminator at epoch 7...
Reloading best generator at epoch 7...


  student_discriminator.load_state_dict(torch.load("best_student_discriminator.pth"))


Checkpoint saved at epoch 7
Epoch 8/500, D Loss: 3.1936, G Loss: -0.3251, D Accuracy: 83.75%, G Accuracy: 35.50%
Reloading best generator at epoch 8...
Checkpoint saved at epoch 8
Epoch 9/500, D Loss: 2.2916, G Loss: -0.6019, D Accuracy: 5.90%, G Accuracy: 87.60%
Best generator saved at epoch 9 with accuracy: 87.60%
Reloading best discriminator at epoch 9...
Checkpoint saved at epoch 9
Epoch 10/500, D Loss: 2.6960, G Loss: -0.6553, D Accuracy: 61.10%, G Accuracy: 70.30%
Reloading best discriminator at epoch 10...
Reloading best generator at epoch 10...
Checkpoint saved at epoch 10
Epoch 11/500, D Loss: 2.4469, G Loss: -0.6423, D Accuracy: 6.40%, G Accuracy: 88.10%
Best generator saved at epoch 11 with accuracy: 88.10%
Reloading best discriminator at epoch 11...
Checkpoint saved at epoch 11
Epoch 12/500, D Loss: 3.2216, G Loss: -0.7126, D Accuracy: 55.80%, G Accuracy: 90.10%
Best generator saved at epoch 12 with accuracy: 90.10%
Reloading best discriminator at epoch 12...
Checkpoint sav

KeyboardInterrupt: 